def get_binned_spikes(spike_times, spike_clusters, cluster_id, epoch_time,
                      pre_time=0.2, post_time=0.5, bin_size=0.025, smoothing=0.025,
                      return_fr=True):
    """Return only the per-event binned activity computed by ``calculate_peths``.

    All arguments are forwarded unchanged to ``calculate_peths``; the PETH summary
    (its first return value) is discarded.

    Parameters
    ----------
    spike_times, spike_clusters : array-like
        Spike times and the cluster id of each spike.
    cluster_id : array-like
        Passed as ``calculate_peths``' cluster-ids argument (presumably a list of
        cluster ids, despite the singular name — confirm with callers).
    epoch_time : array-like
        Event times to align to.
    pre_time, post_time, bin_size, smoothing, return_fr
        Forwarded to ``calculate_peths`` as keyword arguments.

    Returns
    -------
    The second element of ``calculate_peths``' result (elsewhere in this code it is
    used as an array of shape (n_events, n_clusters, n_bins)).
    """
    peth_result = calculate_peths(
        spike_times,
        spike_clusters,
        cluster_id,
        epoch_time,
        pre_time=pre_time,
        post_time=post_time,
        bin_size=bin_size,
        smoothing=smoothing,
        return_fr=return_fr,
    )
    binned = peth_result[1]
    return binned
def compute_peth(self, trial_type, clust, trials_id):
    """Compute the PETH of one cluster averaged over a subset of trials.

    Parameters
    ----------
    trial_type : str
        Key into ``self.trials``; the selected entries are used as alignment times.
    clust : int
        Index into ``self.clust_ids`` of the cluster to analyse.
    trials_id : array-like
        Selection (integer indices or boolean mask) applied to
        ``self.trials[trial_type]``.

    Returns
    -------
    t_peth : ndarray
        Time axis of the PETH relative to the event (``peths.tscale``).
    peth_mean : ndarray
        Mean PETH of the cluster over the selected trials.
    peth_std : ndarray
        Standard error of the mean over the selected trials (despite the name).
    """
    align_times = self.trials[trial_type][trials_id]
    # The per-trial binned output is not needed; `_` also avoids shadowing builtin `bin`.
    peths, _ = calculate_peths(
        self.spikes.times, self.spikes.clusters, [self.clust_ids[clust]], align_times,
        self.t_before, self.t_after)
    peth_mean = peths.means[0, :]
    # Divide by the number of trials actually selected: identical to len(trials_id) for
    # integer indices, and still correct when trials_id is a boolean mask.
    peth_std = peths.stds[0, :] / np.sqrt(len(align_times))
    t_peth = peths.tscale
    return t_peth, peth_mean, peth_std
def test_peths_synthetic(self):
    """Check the output shapes of ``calculate_peths`` on random synthetic spike trains."""
    n_spikes = 20000
    n_clusters = 20
    n_events = 200
    record_length = 1654
    cluster_sel = [1, 2, 3, 6, 15, 16]
    # A local generator seeded with 42 yields the same draws as np.random.seed(42) would,
    # without mutating the global RNG state shared with other tests.
    rng = np.random.RandomState(seed=42)
    spike_times = np.sort(rng.rand(n_spikes) * record_length)
    spike_clusters = rng.randint(0, n_clusters, n_spikes)
    event_times = np.sort(rng.rand(n_events) * record_length)
    peth, fr = calculate_peths(spike_times, spike_clusters,
                               cluster_ids=cluster_sel, align_times=event_times)
    # Expected number of time bins with calculate_peths' default window; presumably
    # (0.2 s + 0.5 s) / 0.025 s = 28 — confirm against calculate_peths' defaults.
    n_bins = 28
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(peth.means.shape[0], len(cluster_sel))
    self.assertEqual(peth.means.shape, peth.stds.shape)
    self.assertEqual(fr.shape, (n_events, len(cluster_sel), n_bins))
    self.assertEqual(peth.tscale.size, n_bins)
# Example: PETHs of two clusters aligned to go-cue times for one ephysChoiceWorld session.
import matplotlib.pyplot as plt
import numpy as np

import alf.io
from brainbox.singlecell import calculate_peths
from oneibl.one import ONE

one = ONE()
# take the first session matching the search criteria
eid = one.search(
    subject='KS004', date=['2019-09-25'], task_protocol='ephysChoiceWorld')[0]
# download the session datasets; the folder of the first one is used as the session path
datasets = one.load(eid, download_only=True)
ses_path = datasets[0].local_path.parent

spikes = alf.io.load_object(ses_path, 'spikes')
trials = alf.io.load_object(ses_path, '_ibl_trials')

peth, bs = calculate_peths(
    spikes.times, spikes.clusters, [225, 52], trials.goCue_times)

# one mean trace per row of peth.means (one row per requested cluster)
plt.plot(peth.tscale, peth.means.T)

# NOTE(review): the shaded band is +/- std / 20, a fixed scaling rather than a computed
# standard error — confirm the intended error measure.
band_style = dict(
    alpha=0.2,
    edgecolor='#1B2ACC',
    facecolor='#089FFF',
    linewidth=4,
    linestyle='dashdot',
    antialiased=True,
)
for m in np.arange(peth.means.shape[0]):
    centre = peth.means[m, :]
    half_width = peth.stds[m, :] / 20
    plt.fill_between(peth.tscale, centre - half_width, centre + half_width, **band_style)
for i, align_event in enumerate(align_events): if align_event == 'stimOff': if choice == stim_side: offset = 1.0 else: offset = 2.0 align_times = trials['feedback_times'][trial_ids[d]] + offset elif align_event == 'movement': raise NotImplementedError else: align_times = trials[align_event + '_times'][trial_ids[d]] peth_, bs = calculate_peths(spikes['times'], spikes['clusters'], cluster_ids, align_times, pre_time=PRE_TIME, post_time=POST_TIME, bin_size=BIN_SIZE, smoothing=SMOOTH_SIZE) peth_means[d][align_event] = peth_.means peth_stds[d][align_event] = peth_.stds binned[d][align_event] = bs # plot peths for each cluster n_trials, n_clusters, _ = binned[d][align_event].shape n_rows = 4 # clusters per page n_plots = int(np.ceil(n_clusters / n_rows)) # n_plots = 3 # for testing for d in ['l', 'r']: for p in range(n_plots):
def peri_event_time_histogram(
        spike_times, spike_clusters, events, cluster_id,  # Everything you need for a basic plot
        t_before=0.2, t_after=0.5, bin_size=0.025, smoothing=0.025,
        as_rate=True, include_raster=False, n_rasters=None, error_bars='std', ax=None,
        pethline_kwargs=None, errbar_kwargs=None, eventline_kwargs=None,
        raster_kwargs=None, **kwargs):
    """
    Plot peri-event time histograms, with the mean firing rate of units centered on a given
    series of events. Can optionally add a raster underneath the PETH plot of individual spike
    trains about the events.

    Parameters
    ----------
    spike_times : array_like
        Spike times (in seconds)
    spike_clusters : array-like
        Cluster identities for each element of spikes
    events : array-like
        Times to align the histogram(s) to
    cluster_id : int
        Identity of the cluster for which to plot a PETH
    t_before : float, optional
        Time before event to plot (default: 0.2s)
    t_after : float, optional
        Time after event to plot (default: 0.5s)
    bin_size : float, optional
        Width of bin for histograms (default: 0.025s)
    smoothing : float, optional
        Sigma of gaussian smoothing to use in histograms. (default: 0.025s)
    as_rate : bool, optional
        Whether to use spike counts or rates in the plot (default: `True`, uses rates)
    include_raster : bool, optional
        Whether to put a raster below the PETH of individual spike trains (default: `False`)
    n_rasters : int, optional
        If include_raster is True, the number of rasters to include. If `None`
        will default to plotting rasters around all provided events. (default: `None`)
    error_bars : {'std', 'sem', 'none'}, optional
        Defines which type of error bars to plot. Options are:
        -- `'std'` for 1 standard deviation
        -- `'sem'` for standard error of the mean
        -- `'none'` for only plotting the mean value
        (default: `'std'`)
    ax : matplotlib axes, optional
        If passed, the function will plot on the passed axes. Existing content is not
        cleared, but the x/y limits, y ticks, axis labels and top/right spines are
        overwritten. If `None`, a new figure is created. (default: `None`)
    pethline_kwargs : dict, optional
        Line properties for the PETH mean line (matplotlib `plot`).
        `None` means `{'color': 'blue', 'lw': 2}`. (default: `None`)
    errbar_kwargs : dict, optional
        Fill properties for the PETH error band (matplotlib `fill_between`).
        `None` means `{'color': 'blue', 'alpha': 0.5}`. (default: `None`)
    eventline_kwargs : dict, optional
        Line properties for the vertical line at the event time (matplotlib `vlines`).
        `None` means `{'color': 'black', 'alpha': 0.5}`. (default: `None`)
    raster_kwargs : dict, optional
        Line properties for the raster ticks (matplotlib `vlines`).
        `None` means `{'color': 'black', 'lw': 0.5}`. (default: `None`)
    **kwargs
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    ax : matplotlib axes
        Axes with all of the plots requested.

    Raises
    ------
    ValueError
        If spike_times and spike_clusters differ in length, fewer than two events are
        given, error_bars is not a recognised option, or events contains NaN/inf values.
    """
    # Build style defaults per call instead of sharing mutable dict defaults between calls.
    if pethline_kwargs is None:
        pethline_kwargs = {'color': 'blue', 'lw': 2}
    if errbar_kwargs is None:
        errbar_kwargs = {'color': 'blue', 'alpha': 0.5}
    if eventline_kwargs is None:
        eventline_kwargs = {'color': 'black', 'alpha': 0.5}
    if raster_kwargs is None:
        raster_kwargs = {'color': 'black', 'lw': 0.5}

    # Check to make sure if we fail, we fail in an informative way
    if len(spike_times) != len(spike_clusters):
        raise ValueError('Spike times and clusters are not of the same shape')
    if len(events) == 0:
        raise ValueError('Cannot make a PETH with no events.')
    if len(events) == 1:
        raise ValueError('Cannot make a PETH with only one event.')
    if error_bars not in ('std', 'sem', 'none'):
        raise ValueError('Invalid error bar type was passed.')
    if not np.all(np.isfinite(events)):
        raise ValueError(
            'There are NaN or inf values in the list of events passed. '
            ' Please remove non-finite data points and try again.')

    # Compute peths (the per-event binned spikes are not needed for plotting)
    peths, _ = singlecell.calculate_peths(
        spike_times, spike_clusters, [cluster_id], events,
        pre_time=t_before, post_time=t_after, bin_size=bin_size,
        smoothing=smoothing, return_fr=as_rate)

    # Construct an axis object if none passed
    if ax is None:
        plt.figure()
        ax = plt.gca()

    # Plot the curve and add error bars
    mean = peths.means[0, :]
    ax.plot(peths.tscale, mean, **pethline_kwargs)
    if error_bars == 'std':
        bars = peths.stds[0, :]
    elif error_bars == 'sem':
        bars = peths.stds[0, :] / np.sqrt(len(events))
    else:
        bars = np.zeros_like(mean)
    if error_bars != 'none':
        ax.fill_between(peths.tscale, mean - bars, mean + bars, **errbar_kwargs)

    # Plot the event marker line. Extends to 5% higher than max value of means plus any error bar.
    plot_edge = (mean.max() + bars[mean.argmax()]) * 1.05
    ax.vlines(0., 0., plot_edge, **eventline_kwargs)
    # Set the limits on the axes to t_before and t_after. Either set the ylim to the 0 and max
    # values of the PETH, or if we want to plot a spike raster below, create an equal amount of
    # blank space below the zero where the raster will go.
    ax.set_xlim([-t_before, t_after])
    ax.set_ylim([-plot_edge if include_raster else 0., plot_edge])
    # Put y ticks only at min, max, and zero
    if mean.min() != 0:
        ax.set_yticks([0, mean.min(), mean.max()])
    else:
        ax.set_yticks([0., mean.max()])

    # Move the x axis line from the bottom of the plotting space to zero if including a raster,
    # Then plot the raster
    if include_raster:
        if n_rasters is None:
            n_rasters = len(events)
        if n_rasters > 60:
            warn("Number of raster traces is greater than 60. This might look bad on the plot.")
        ax.axhline(0., color='black')
        raster_events = events[:n_rasters]
        # One equal-height row per trace in [-plot_edge, 0]. linspace guarantees exactly
        # len + 1 edges; np.arange with a float step can be off by one, and fails on a
        # zero step when plot_edge == 0.
        tickedges = np.linspace(0., -plot_edge, len(raster_events) + 1)
        clu_spks = spike_times[spike_clusters == cluster_id]
        for i, t in enumerate(raster_events):
            in_window = (clu_spks >= t - t_before) & (clu_spks <= t + t_after)
            event_spks = clu_spks[in_window]
            ax.vlines(event_spks - t, tickedges[i + 1], tickedges[i], **raster_kwargs)
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes', y=0.75)
    else:
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time (s) after event')
    return ax
def plot_grating_figures(
        session_path, cluster_ids_summary, cluster_ids_selected, save_dir=None, format='png',
        pre_time=0.5, post_time=2.5, bin_size=0.005, smoothing=0.025, n_rand_clusters=20,
        plot_summary=True, plot_selected=True):
    """
    Produces two summary figures for the oriented grating protocol; the first summary figure
    contains plots that compare different measures during the first and second grating
    protocols, such as orientation selectivity index (OSI), orientation preference, fraction of
    visual clusters, PSTHs, firing rate histograms, etc. The second summary figure contains
    plots of polar PSTHs and corresponding rasters for a random subset of visually responsive
    clusters.

    Parameters
    ----------
    session_path : str
        absolute path to experimental session directory
    cluster_ids_summary : list or array-like
        the clusters for which metrics and summary psths/rasters are computed; converted to a
        numpy array internally. NOTE(review): there is no fallback for an empty list (an
        earlier docstring claimed all responsive clusters would then be used) — pass the
        cluster ids explicitly.
    cluster_ids_selected : list
        the clusters for which to plot individual psths/rasters; if empty, `n_rand_clusters`
        are randomly chosen from `cluster_ids_summary` (all of them if there are not more
        than `n_rand_clusters`)
    save_dir : str or NoneType
        if NoneType, no save file is passed to the plotting functions (presumably the
        figures are only displayed); else the absolute path of the directory in which
        figures will be saved (created if it does not exist)
    format : str
        file format, i.e. 'png' | 'pdf' | 'jpg'
    pre_time : float
        time (sec) to plot before grating presentation onset
    post_time : float
        time (sec) to plot after grating presentation onset (should include length of stimulus)
    bin_size : float
        size of bins for raster plots/psths
    smoothing : float
        size of smoothing kernel (sec)
    n_rand_clusters : int
        the number of random clusters to choose for which to plot psths/rasters if
        `cluster_ids_selected` is empty
    plot_summary : bool
        a flag for plotting the summary figure
    plot_selected : bool
        a flag for plotting the selected units figure

    Returns
    -------
    fig_dict : dict
        handles to the figures generated: 'gr_summary' (if `plot_summary`) and
        'gr_selected' (if `plot_selected`)
    metrics : dict
        - 'osi' (dict): keys 'beg', 'end' point to arrays of osis during these epochs
        - 'orientation_pref' (dict): keys 'beg', 'end' point to arrays of orientation
          preference
        - 'frac_resp_by_depth' (dict): 'fraction' (per-epoch lists of the fraction of
          responsive clusters in each depth bin) and 'depth' (depth bin centers)
    """
    fig_dict = {}
    # Work on an array: boolean-mask indexing (cluster_ids[resp[epoch]]) and elementwise
    # comparison (cluster_ids == cluster_idx) below do not work on plain lists.
    cluster_ids = np.asarray(cluster_ids_summary)
    epochs = ['beg', 'end']

    # -------------------------
    # load required alf objects
    # -------------------------
    print('loading alf objects...', end='', flush=True)
    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    gratings = ioalf.load_object(session_path, '_iblcertif_.odsgratings')
    spontaneous = ioalf.load_object(session_path, '_iblcertif_.spontaneous')
    grating_times = {
        'beg': gratings['odsgratings.times.00'],
        'end': gratings['odsgratings.times.01']}
    grating_vals = {
        'beg': gratings['odsgratings.stims.00'],
        'end': gratings['odsgratings.stims.01']}
    spont_times = {
        'beg': spontaneous['spontaneous.times.00'],
        'end': spontaneous['spontaneous.times.01']}
    print('done')

    # --------------------------
    # calculate relevant metrics
    # --------------------------
    print('calcuating mean responses to gratings...', end='', flush=True)
    # spikes belonging to the clusters of interest
    mask_clust = np.isin(spikes.clusters, cluster_ids)
    # spikes falling between the first and last grating time of either epoch
    mask_times = np.full(spikes.times.shape, fill_value=False)
    for epoch in epochs:
        mask_times |= (spikes.times >= grating_times[epoch].min()) & \
                      (spikes.times <= grating_times[epoch].max())
    resp = {}
    for epoch in epochs:
        resp[epoch] = are_neurons_responsive(
            spikes.times[mask_clust], spikes.clusters[mask_clust], grating_times[epoch],
            grating_vals[epoch], spont_times[epoch])
    responses = {}
    for epoch in epochs:
        responses[epoch] = bin_responses(
            spikes.times[mask_clust], spikes.clusters[mask_clust], grating_times[epoch],
            grating_vals[epoch])
    responses_mean = {epoch: np.mean(responses[epoch], axis=2) for epoch in epochs}
    print('done')

    # calculate osi and orientation preference
    print('calcuating osi/orientation preference...', end='', flush=True)
    ori_pref = {}
    osi = {}
    for epoch in epochs:
        osi[epoch], ori_pref[epoch] = compute_selectivity(
            responses_mean[epoch], np.unique(grating_vals[epoch]), 'ori')
    print('done')

    # calculate depth vs osi ratio (osi_beg/osi_end)
    print('calcuating osi ratio as a function of depth...', end='', flush=True)
    depths = np.array([clusters.depths[c] for c in cluster_ids])
    ratios = np.array([osi['beg'][c] / osi['end'][c] for c in range(len(cluster_ids))])
    print('done')

    # calculate fraction of visual neurons by depth
    print('calcuating fraction of visual clusters by depth...', end='', flush=True)
    n_bins = 10
    min_depth = np.min(clusters['depths'])
    max_depth = np.max(clusters['depths'])
    # `min_depth - 1` so the shallowest cluster falls inside the first (lo, up] bin
    depth_limits = np.linspace(min_depth - 1, max_depth, n_bins + 1)
    depth_avg = (depth_limits[:-1] + depth_limits[1:]) / 2
    frac_responsive = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        # just look at responsive clusters during this epoch
        cids_resp = cluster_ids[resp[epoch]]
        for lo_limit, up_limit in zip(depth_limits[:-1], depth_limits[1:]):
            # assumes the position in clusters.depths is the cluster id
            cids_curr_depth = np.where(
                (lo_limit < clusters.depths) & (clusters.depths <= up_limit))[0]
            n_curr_depth = len(cids_curr_depth)
            n_resp = np.sum(np.isin(cids_resp, cids_curr_depth))
            # empty depth bin: report NaN explicitly rather than dividing by zero
            frac_responsive[epoch].append(n_resp / n_curr_depth if n_curr_depth else np.nan)
    # package for plotting
    responsive = {'fraction': frac_responsive, 'depth': depth_avg}
    print('done')

    # calculate PSTH averaged over all clusters/orientations
    print('calcuating average PSTH...', end='', flush=True)
    peths = {epoch: [] for epoch in epochs}
    peths_avg = {epoch: [] for epoch in epochs}
    for epoch in epochs:
        stim_ids = np.unique(grating_vals[epoch])
        peths[epoch] = {i: None for i in range(len(stim_ids))}
        peths_avg_tmp = []
        for i, stim_id in enumerate(stim_ids):
            curr_stim_idxs = np.where(grating_vals[epoch] == stim_id)
            align_times = grating_times[epoch][curr_stim_idxs, 0][0]
            peths[epoch][i], _ = calculate_peths(
                spikes.times[mask_times], spikes.clusters[mask_times], cluster_ids,
                align_times, pre_time=pre_time, post_time=post_time, bin_size=bin_size,
                smoothing=smoothing, return_fr=True)
            # average over clusters -> one row per stimulus
            peths_avg_tmp.append(
                np.mean(peths[epoch][i]['means'], axis=0, keepdims=True))
        peths_avg_tmp = np.vstack(peths_avg_tmp)
        # 'std' is the standard error across stimuli
        peths_avg[epoch] = {
            'mean': np.mean(peths_avg_tmp, axis=0),
            'std': np.std(peths_avg_tmp, axis=0) / np.sqrt(peths_avg_tmp.shape[0])}
    peths_avg['bin_size'] = bin_size
    peths_avg['on_idx'] = int(pre_time / bin_size)
    # NOTE(review): assumes a 2 s grating presentation — confirm for other protocols
    peths_avg['off_idx'] = peths_avg['on_idx'] + int(2 / bin_size)
    print('done')

    # compute rasters for entire orientation sequence at beg/end epoch
    if plot_summary:
        print('computing rasters for example stimulus sequences...', end='', flush=True)
        r = {epoch: None for epoch in epochs}
        r_times = {epoch: None for epoch in epochs}
        r_clusters = {epoch: None for epoch in epochs}
        for epoch in epochs:
            # restrict activity to a single stim series; assumes each possible grating direction
            # is displayed before repeating
            n_stims = len(np.unique(grating_vals[epoch]))
            mask_idxs_e = (spikes.times >= grating_times[epoch][:n_stims].min()) & \
                          (spikes.times <= grating_times[epoch][:n_stims].max())
            r_tmp, r_times[epoch], r_clusters[epoch] = bincount2D(
                spikes.times[mask_idxs_e], spikes.clusters[mask_idxs_e], bin_size)
            # order activity by anatomical depth of neurons; each cluster's depth is that of
            # its last spike in the window (later dict entries overwrite earlier ones)
            clust_depth = dict(zip(spikes.clusters[mask_idxs_e], spikes.depths[mask_idxs_e]))
            y = sorted([[i, clust_depth[i]] for i in clust_depth])
            isort = np.argsort([x[1] for x in y])
            r[epoch] = r_tmp[isort, :]
        # package for plotting
        rasters = {'spikes': r, 'times': r_times, 'clusters': r_clusters, 'bin_size': bin_size}
        print('done')

    # -------------------------------------------------
    # compute psths and rasters for individual clusters
    # -------------------------------------------------
    if plot_selected:
        print('computing psths and rasters for clusters...', end='', flush=True)
        if len(cluster_ids_selected) == 0:
            if n_rand_clusters < len(cluster_ids):
                cluster_idxs = np.random.choice(cluster_ids, size=n_rand_clusters, replace=False)
            else:
                cluster_idxs = cluster_ids
        else:
            cluster_idxs = cluster_ids_selected
        mean_responses = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        osis = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        binned = {cluster: {epoch: [] for epoch in epochs} for cluster in cluster_idxs}
        for cluster_idx in cluster_idxs:
            # position of this cluster within cluster_ids (rows of responses_mean / osi)
            cluster = np.where(cluster_ids == cluster_idx)[0]
            for epoch in epochs:
                mean_responses[cluster_idx][epoch] = responses_mean[epoch][cluster, :][0]
                osis[cluster_idx][epoch] = osi[epoch][cluster]
                stim_ids = np.unique(grating_vals[epoch])
                binned[cluster_idx][epoch] = {j: None for j in range(len(stim_ids))}
                for j, stim_id in enumerate(stim_ids):
                    curr_stim_idxs = np.where(grating_vals[epoch] == stim_id)
                    align_times = grating_times[epoch][curr_stim_idxs, 0][0]
                    _, binned[cluster_idx][epoch][j] = calculate_peths(
                        spikes.times[mask_times], spikes.clusters[mask_times], [cluster_idx],
                        align_times, pre_time=pre_time, post_time=post_time,
                        bin_size=bin_size)
        print('done')

    # --------------
    # output figures
    # --------------
    print('producing figures...', end='')

    def _save_file(basename):
        # Path for a figure file, or None when not saving; creates save_dir on demand so
        # either figure can be saved on its own.
        if save_dir is None:
            return None
        os.makedirs(save_dir, exist_ok=True)
        return os.path.join(save_dir, basename + '.' + format)

    if plot_summary:
        fig_gr_summary = plot_summary_figure(
            ratios=ratios, depths=depths, responsive=responsive, peths_avg=peths_avg, osi=osi,
            ori_pref=ori_pref, responses_mean=responses_mean, rasters=rasters,
            save_file=_save_file('grating_summary_figure'))
        fig_gr_summary.suptitle('Summary Grating Responses')
        fig_dict['gr_summary'] = fig_gr_summary

    if plot_selected:
        fig_gr_selected = plot_psths_and_rasters(
            mean_responses, binned, osis, grating_vals, on_idx=peths_avg['on_idx'],
            off_idx=peths_avg['off_idx'], bin_size=bin_size,
            save_file=_save_file('grating_random_responses'))
        fig_gr_selected.suptitle('Selected Units Grating Responses')
        fig_dict['gr_selected'] = fig_gr_selected
    print('done')

    # -----------------------------
    # package up and return metrics
    # -----------------------------
    metrics = {
        'osi': osi,
        'orientation_pref': ori_pref,
        'frac_resp_by_depth': responsive,
    }
    return fig_dict, metrics
def make(self, key):
    """Compute depth-binned population PETHs for correct trials and insert them.

    Spikes of all clusters matching ``key`` are pooled and split into depth bins of
    80 (depth units). For each depth bin, the per-cluster mean PETHs from
    ``singlecell.calculate_peths`` (0.3 s before to 1 s after the event) are summed
    across clusters. The mean of that summed PETH over -0.3 s < t < 0 s is stored as
    the bin's baseline.

    Parameters
    ----------
    key : dict
        DataJoint key. ``key['event']`` selects the alignment event and must be one of
        'feedback', 'stim on' or 'movement'. The dict is updated in place with the
        computed attributes and then inserted.

    Raises
    ------
    ValueError
        If ``key['event']`` is not a supported event.
    """
    # alignment event -> trial field holding its times
    event_fields = {
        'feedback': 'trial_feedback_time',
        'stim on': 'trial_stim_on_time',
        'movement': 'movement_onset',
    }
    if key['event'] not in event_fields:
        raise ValueError('Unsupported event {!r}; expected one of {}'.format(
            key['event'], sorted(event_fields)))

    clusters_spk_depths, clusters_spk_times, clusters_ids = \
        (ephys.DefaultCluster & key).fetch(
            'cluster_spikes_depths', 'cluster_spikes_times', 'cluster_id')
    spikes_depths = np.hstack(clusters_spk_depths)
    spikes_times = np.hstack(clusters_spk_times)
    # cluster id of every pooled spike, in the same order as spikes_times/spikes_depths
    spikes_clusters = np.hstack(
        [[cluster_id] * len(cluster_spk_depths)
         for (cluster_id, cluster_spk_depths) in zip(clusters_ids, clusters_spk_depths)])

    # correct trials only; movement alignment also needs the wheel movement onsets
    if key['event'] == 'movement':
        q = behavior.TrialSet.Trial * wheel.MovementTimes & key & 'trial_feedback_type=1'
    else:
        q = behavior.TrialSet.Trial & key & 'trial_feedback_type=1'
    trials = q.fetch()

    bin_size_depth = 80
    min_depth = np.nanmin(spikes_depths)
    max_depth = np.nanmax(spikes_depths)
    bin_edges = np.arange(min_depth, max_depth, bin_size_depth)
    spk_bin_ids = np.digitize(spikes_depths, bin_edges)
    # np.digitize places NaN values in the last bin (index len(bin_edges)). Move NaN-depth
    # spikes to bin 0, which the loop below never visits, so they are excluded instead.
    spk_bin_ids[np.isnan(spikes_depths)] = 0

    edges = np.hstack([bin_edges, [bin_edges[-1] + bin_size_depth]])
    key.update(trial_type='Correct All',
               depth_bin_centers=(edges[:-1] + edges[1:]) / 2)

    event_times = trials[event_fields[key['event']]]

    peth_list = []
    baseline_list = []
    for i in tqdm(np.arange(len(bin_edges)) + 1, position=0):
        in_bin = spk_bin_ids == i
        spikes_ibin = spikes_times[in_bin]
        clusters_ibin = spikes_clusters[in_bin]
        cluster_ids = np.unique(clusters_ibin)
        peths, _ = singlecell.calculate_peths(
            spikes_ibin, clusters_ibin, cluster_ids, event_times,
            pre_time=0.3, post_time=1)
        if len(peths.means):
            time = peths.tscale
            peth = np.sum(peths.means, axis=0)
            baseline = peth[np.logical_and(time > -0.3, time < 0)]
            peth_list.append(peth)
            baseline_list.append(np.mean(baseline))
        else:
            # no clusters (hence no spikes) in this depth bin: flat zero PETH and baseline
            peth_list.append(np.zeros_like(peths.tscale))
            baseline_list.append(0)

    key.update(depth_peth=np.vstack(peth_list),
               depth_baseline=np.array(baseline_list),
               time_bin_centers=peths.tscale)
    self.insert1(key, skip_duplicates=True)