def image_raw_data(raw, fs, chn_coords=None, cmap='bone', title=None, display=False, gain=-90, **kwargs):
    """
    Prepare data for a 2D image plot of a snippet of raw data

    :param raw: raw data array indexed as [sample, channel]
    :param fs: sampling rate in Hz, used to express the sample axis in milliseconds
    :param chn_coords: optional channel coordinates; when given, column 1 is used as the
        distance from probe tip (um), otherwise channels are labelled by their index
    :param cmap: matplotlib colormap name
    :param title: plot title, defaults to 'Raw data'
    :param display: generate figure
    :param gain: display gain in dB; colour limits are +/- 4 * 10 ** (gain / 20)
    :param kwargs: extra keyword arguments forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    if chn_coords is None:
        ylabel = 'Channel index'
        depths = np.arange(raw.shape[1])
    else:
        ylabel = 'Distance from probe tip (um)'
        depths = chn_coords[:, 1]

    # x-limits span the first to the last sample, converted to milliseconds
    time_lim = np.array([0, raw.shape[0] - 1]) / fs * 1e3

    # Symmetric colour limits derived from the gain expressed in dB
    level = 10 ** (gain / 20) * 4
    colour_lim = level * np.array([-1, 1])

    image = ImagePlot(raw, y=depths, cmap=cmap)
    image.set_labels(title=title or 'Raw data', xlabel='Time (ms)', ylabel=ylabel, clabel='Power (uV)')
    image.set_clim(clim=colour_lim)
    image.set_xlim(xlim=time_lim)
    image.set_ylim()

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300), avg_across_depth=False,
                            cmap='viridis', display=False, **kwargs):
    """
    Prepare data for 2D image plot of LFP power spectrum along depth of probe

    :param lfp_power: LFP power array indexed as [frequency, channel]
    :param lfp_freq: frequencies (Hz) corresponding to the first axis of lfp_power
    :param chn_coords: channel coordinates; column 1 is used as the distance from probe tip (um).
        Rows are assumed to line up with chn_inds (TODO confirm with callers)
    :param chn_inds: indices into the channel axis of lfp_power of the channels to display
    :param freq_range: half-open frequency interval [low, high) in Hz to display
    :param avg_across_depth: Whether to average across all channels at the same depth
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param kwargs: extra keyword arguments forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    freqs = lfp_freq[freq_idx]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)

    # Zero power gives -inf dB; mark it as missing so it does not dominate the colour scale
    with np.errstate(divide='ignore'):
        lfp_db = 10 * np.log10(lfp)
    lfp_db[np.isinf(lfp_db)] = np.nan

    y = chn_coords[:, 1]
    if avg_across_depth:
        # Group channels by depth and average every channel in each group. This does not rely on
        # channels at the same depth being adjacent, nor on there being at most two per depth
        y, depth_group = np.unique(chn_coords[:, 1], return_inverse=True)
        lfp_db = np.stack([np.mean(lfp_db[:, depth_group == i], axis=1) for i in range(y.size)], axis=1)

    data = ImagePlot(lfp_db, x=freqs, y=y, cmap=cmap)
    data.set_labels(title='LFP Power Spectrum', xlabel='Frequency (Hz)',
                    ylabel='Distance from probe tip (um)', clabel='LFP Power (dB)')
    # nanquantile: NaNs introduced above would otherwise make both colour limits NaN
    data.set_clim(clim=np.nanquantile(lfp_db, [0.1, 0.9]))

    if display:
        # Unpacked as (ax, fig) to match the other plotting helpers in this module
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def image_rms_plot(rms_amps, rms_times, chn_coords, chn_inds, avg_across_depth=False, median_subtract=True,
                   cmap='plasma', band='AP', display=False, **kwargs):
    """
    Prepare data for 2D image plot of RMS data along depth of probe

    :param rms_amps: RMS array indexed as [time, channel]; values are multiplied by 1e6 for display
        in uV (so presumably given in V — confirm)
    :param rms_times: times (s) corresponding to the first axis of rms_amps
    :param chn_coords: channel coordinates; column 1 is used as the distance from probe tip (um).
        Rows are assumed to line up with chn_inds (TODO confirm with callers)
    :param chn_inds: indices into the channel axis of rms_amps of the channels to display
    :param avg_across_depth: Whether to average across all channels at the same depth
    :param median_subtract: Whether to apply median subtraction correction
    :param cmap: matplotlib colormap name
    :param band: Frequency band of rms data, can be either 'LF' or 'AP'
    :param display: generate figure
    :param kwargs: extra keyword arguments forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    rms = rms_amps[:, chn_inds] * 1e6
    y = chn_coords[:, 1]

    if avg_across_depth:
        # Group channels by depth and average every channel in each group. This does not rely on
        # channels at the same depth being adjacent, nor on there being at most two per depth
        y, depth_group = np.unique(chn_coords[:, 1], return_inverse=True)
        rms = np.stack([np.mean(rms[:, depth_group == i], axis=1) for i in range(y.size)], axis=1)

    if median_subtract:
        # Remove each time sample's median across channels, then add back the mean of those
        # medians so the overall level is preserved
        row_median = np.median(rms, axis=1, keepdims=True)
        rms = rms - row_median + np.mean(row_median)

    data = ImagePlot(rms, x=rms_times, y=y, cmap=cmap)
    data.set_labels(title=f'{band} RMS', xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)', clabel=f'{band} RMS (uV)')
    # nanquantile: a single NaN sample would otherwise make both colour limits NaN
    data.set_clim(clim=np.nanquantile(rms, [0.1, 0.9]))

    if display:
        # Unpacked as (ax, fig) to match the other plotting helpers in this module
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
def image_crosscorr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=40, cmap='viridis',
                         display=False, title=None, **kwargs):
    """
    Prepare data for 2D cross correlation plot of data across depth

    Spikes are binned in time and depth with bincount2D. The correlation matrix between the rows
    of the binned counts is then plotted against the depth bins on both axes.

    :param spike_depths: depth of each spike (um)
    :param spike_times: time of each spike (s)
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Correlation'
    :param kwargs: extra keyword arguments forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    max_depth = np.max(chn_coords[:, 1])
    counts, _, depth_bins = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=[0, max_depth])

    corr = np.corrcoef(counts)
    # Rows with zero variance yield NaN correlations; display them as zero
    corr[np.isnan(corr)] = 0

    image = ImagePlot(corr, x=depth_bins, y=depth_bins, cmap=cmap)
    image.set_labels(title=title or 'Correlation', xlabel='Distance from probe tip (um)',
                     ylabel='Distance from probe tip (um)', clabel='Correlation')

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def image_fr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=5, cmap='binary', display=False,
                  title=None, **kwargs):
    """
    Prepare data for a 2D raster plot of firing rate across the recording

    :param spike_depths: depth of each spike (um)
    :param spike_times: time of each spike (s)
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap: matplotlib colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Firing Rate'
    :param kwargs: extra keyword arguments forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    max_depth = np.max(chn_coords[:, 1])
    counts, times, depths = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=[0, max_depth])

    # Transpose the count matrix and turn counts per time bin into a rate in Hz
    rate = np.transpose(counts) / t_bin
    # Colour limits span the range of the time-averaged rate (mean over axis 0)
    mean_rate = rate.mean(axis=0)

    image = ImagePlot(rate, x=times, y=depths, cmap=cmap)
    image.set_labels(title=title or 'Firing Rate', xlabel='Time (s)',
                     ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
    image.set_clim(clim=(mean_rate.min(), mean_rate.max()))

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
             display=False, cmap='hot', ax=None):
    """
    Plot cumulative amplitude of spikes across depth

    For each depth bin, computes the rate of spikes whose amplitude is at or above each amplitude
    bin edge. The rate is the count divided by the time of the last spike.

    :param spike_amps: spike amplitudes (multiplied by 1e6 for display in uV, so presumably in V — confirm)
    :param spike_depths: depth of each spike (um)
    :param spike_times: time of each spike (s); the maximum is used as the recording duration
    :param n_amp_bins: number of amplitude bins to use
    :param d_bin: the value of the depth bins in um (default is 40 um)
    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from
        spike_amps as its [0, 0.9] quantiles. Spikes above amp_max count towards the top bin; spikes
        below amp_min are ignored
    :param d_range: depth range to use, by default [0, 3840]
    :param display: whether or not to display plot
    :param cmap: matplotlib colormap name
    :param ax: optional matplotlib axis passed to plot_image when display=True
    :return: ImagePlot object, if display=True returns data.convert2dict() together with the
        matplotlib fig and ax objects
    """
    # Explicit None checks: `x or default` raises ValueError when x is a numpy array
    if amp_range is None:
        amp_range = np.quantile(spike_amps, (0, 0.9))
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    if d_range is None:
        d_range = [0, 3840]
    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
    t_bin = np.max(spike_times)

    def histc(x, bins):
        # np.digitize returns i with bins[i-1] <= x < bins[i]; values >= bins[-1] get len(bins) and
        # are counted in the last bin. Values below bins[0] get 0 and are dropped, rather than
        # wrapping around to the last bin via index -1
        idx = np.digitize(x, bins)
        idx = idx[idx > 0] - 1
        return np.bincount(idx, minlength=bins.size).astype(float)

    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
    for d, (d_lo, d_hi) in enumerate(zip(depth_bins[:-1], depth_bins[1:])):
        # Depth bins are half-open (d_lo, d_hi]
        in_bin = (spike_depths > d_lo) & (spike_depths <= d_hi)
        h = histc(spike_amps[in_bin], amp_bins) / t_bin
        # Reverse cumulative sum: rate of spikes with amplitude >= each bin edge
        cdfs[d, :] = np.cumsum(h[::-1])[::-1]
    cdfs[cdfs == 0] = np.nan

    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')

    if display:
        ax, fig = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}, ax=ax)
        return data.convert2dict(), fig, ax

    return data