Ejemplo n.º 1
0
def image_raw_data(raw,
                   fs,
                   chn_coords=None,
                   cmap='bone',
                   title=None,
                   display=False,
                   gain=-90,
                   **kwargs):
    """
    Prepare data for 2D image plot of a snippet of raw data

    :param raw: 2D array; axis 0 is laid out along the time axis and axis 1 along
        the channel axis
    :param fs: sampling frequency used to convert sample indices to milliseconds
        (presumably in Hz - confirm against caller)
    :param chn_coords: optional channel coordinates; when given, column 1 is used
        as the y axis, otherwise the channel index is used
    :param cmap:
    :param title: plot title, 'Raw data' when not given
    :param display: generate figure
    :param gain: value converted to symmetric colour limits as
        10 ** (gain / 20) * 4 * [-1, 1]
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the fig and ax objects from plot_image
    """
    if not title:
        title = 'Raw data'

    # The y axis is either the channel index or the second coordinate column
    if chn_coords is None:
        ylabel = 'Channel index'
        y = np.arange(raw.shape[1])
    else:
        ylabel = 'Distance from probe tip (um)'
        y = chn_coords[:, 1]

    # Time of the first and last sample, in ms
    n_samples = raw.shape[0]
    time_limits = np.array([0, n_samples - 1]) / fs * 1e3

    # Symmetric colour limits derived from the gain
    colour_limits = 10**(gain / 20) * 4 * np.array([-1, 1])

    data = ImagePlot(raw, y=y, cmap=cmap)
    data.set_labels(title=title,
                    xlabel='Time (ms)',
                    ylabel=ylabel,
                    clabel='Power (uV)')
    data.set_clim(clim=colour_limits)
    data.set_xlim(xlim=time_limits)
    data.set_ylim()

    if not display:
        return data

    ax, fig = plot_image(data.convert2dict(), **kwargs)
    return data.convert2dict(), fig, ax
Ejemplo n.º 2
0
def image_lfp_spectrum_plot(lfp_power,
                            lfp_freq,
                            chn_coords,
                            chn_inds,
                            freq_range=(0, 300),
                            avg_across_depth=False,
                            cmap='viridis',
                            display=False):
    """
    Prepare data for 2D image plot of LFP power spectrum along depth of probe

    The power is converted to dB (10 * log10); infinite values (e.g. from zero
    power) are replaced by NaN and ignored when computing the colour limits.

    :param lfp_power: 2D array of LFP power indexed as [frequency, channel]
    :param lfp_freq: 1D array of frequencies matching axis 0 of lfp_power
    :param chn_coords: channel coordinates; column 1 is used as the distance
        from the probe tip (presumably one row per entry of chn_inds - confirm)
    :param chn_inds: indices used to select the channel columns of lfp_power
    :param freq_range: (min, max) frequencies to keep, min inclusive, max exclusive
    :param avg_across_depth: Whether to average across channels at same depth
    :param cmap:
    :param display: generate figure
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the matplotlib fig and ax objects
    """

    freq_idx = np.where((lfp_freq >= freq_range[0])
                        & (lfp_freq < freq_range[1]))[0]
    freqs = lfp_freq[freq_idx]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)

    # Zero power maps to -inf dB, which is replaced by NaN just below, so the
    # divide-by-zero warning carries no information here
    with np.errstate(divide='ignore'):
        lfp_db = 10 * np.log10(lfp)
    lfp_db[np.isinf(lfp_db)] = np.nan

    x = freqs
    y = chn_coords[:, 1]

    # Average across channels that are at the same depth
    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(chn_coords[:, 1],
                                                  return_index=True,
                                                  return_counts=True)
        # Depths shared by exactly two channels are paired with the next index;
        # every other depth is averaged with itself (i.e. left unchanged).
        # NOTE(review): this assumes the two channels of a pair sit at
        # consecutive indices - confirm against the probe geometry
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1

        lfp_db = np.mean([lfp_db[:, chn_idx], lfp_db[:, chn_idx_eq]], axis=0)
        y = chn_depth

    data = ImagePlot(lfp_db, x=x, y=y, cmap=cmap)
    data.set_labels(title='LFP Power Spectrum',
                    xlabel='Frequency (Hz)',
                    ylabel='Distance from probe tip (um)',
                    clabel='LFP Power (dB)')
    # lfp_db may contain NaN (see above); a plain quantile would then be NaN
    data.set_clim(clim=np.nanquantile(lfp_db, [0.1, 0.9]))

    if display:
        # plot_image is unpacked as (ax, fig), as in image_fr_plot and
        # image_crosscorr_plot, so that fig and ax are returned in that order
        ax, fig = plot_image(data.convert2dict())
        return data.convert2dict(), fig, ax

    return data
Ejemplo n.º 3
0
def image_rms_plot(rms_amps,
                   rms_times,
                   chn_coords,
                   chn_inds,
                   avg_across_depth=False,
                   median_subtract=True,
                   cmap='plasma',
                   band='AP',
                   display=False):
    """
    Prepare data for 2D image plot of RMS data along depth of probe

    :param rms_amps: 2D array of RMS amplitudes indexed as [time, channel]; the
        values are multiplied by 1e6 before plotting (presumably V to uV - confirm)
    :param rms_times: values used for the x (time) axis
    :param chn_coords: channel coordinates; column 1 is used as the distance
        from the probe tip (presumably one row per entry of chn_inds - confirm)
    :param chn_inds: indices used to select the channel columns of rms_amps
    :param avg_across_depth: Whether to average across channels at same depth
    :param median_subtract: Whether to apply median subtraction correction
    :param cmap:
    :param band: Frequency band of rms data, can be either 'LF' or 'AP'
    :param display: generate figure
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the matplotlib fig and ax objects
    """

    rms = rms_amps[:, chn_inds] * 1e6
    x = rms_times
    y = chn_coords[:, 1]

    # Average across channels that are at the same depth
    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(chn_coords[:, 1],
                                                  return_index=True,
                                                  return_counts=True)
        # Depths shared by exactly two channels are paired with the next index;
        # every other depth is averaged with itself (i.e. left unchanged).
        # NOTE(review): this assumes the two channels of a pair sit at
        # consecutive indices - confirm against the probe geometry
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1
        rms = np.mean([rms[:, chn_idx], rms[:, chn_idx_eq]], axis=0)
        y = chn_depth

    if median_subtract:
        # Remove each row's own median, then add back the mean of those
        # medians so the overall level of the image is preserved
        row_medians = np.median(rms, axis=1, keepdims=True)
        rms = rms - row_medians + np.mean(row_medians)

    data = ImagePlot(rms, x=x, y=y, cmap=cmap)
    data.set_labels(title=f'{band} RMS',
                    xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)',
                    clabel=f'{band} RMS (uV)')
    data.set_clim(clim=np.quantile(rms, [0.1, 0.9]))

    if display:
        # plot_image is unpacked as (ax, fig), as in image_fr_plot and
        # image_crosscorr_plot, so that fig and ax are returned in that order
        ax, fig = plot_image(data.convert2dict())
        return data.convert2dict(), fig, ax

    return data
Ejemplo n.º 4
0
def image_crosscorr_plot(spike_depths,
                         spike_times,
                         chn_coords,
                         t_bin=0.05,
                         d_bin=40,
                         cmap='viridis',
                         display=False,
                         title=None,
                         **kwargs):
    """
    Prepare data for 2D cross correlation plot of data across depth

    Spikes are binned with bincount2D and the correlation coefficients between
    the rows of the binned counts are computed; NaN coefficients are set to 0.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the
        upper depth limit of the binning
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap:
    :param display: generate figure
    :param title: plot title, 'Correlation' when not given
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the fig and ax objects from plot_image
    """
    if not title:
        title = 'Correlation'

    deepest = np.max(chn_coords[:, 1])
    # Only the counts and the depth scale are needed; the time scale is unused
    counts, _, depth_scale = bincount2D(spike_times,
                                        spike_depths,
                                        t_bin,
                                        d_bin,
                                        ylim=[0, deepest])

    corr = np.corrcoef(counts)
    corr[np.isnan(corr)] = 0

    # The image is square: the depth scale labels both axes
    depth_label = 'Distance from probe tip (um)'
    data = ImagePlot(corr, x=depth_scale, y=depth_scale, cmap=cmap)
    data.set_labels(title=title,
                    xlabel=depth_label,
                    ylabel=depth_label,
                    clabel='Correlation')

    if not display:
        return data

    ax, fig = plot_image(data.convert2dict(), **kwargs)
    return data.convert2dict(), fig, ax
Ejemplo n.º 5
0
def image_fr_plot(spike_depths,
                  spike_times,
                  chn_coords,
                  t_bin=0.05,
                  d_bin=5,
                  cmap='binary',
                  display=False,
                  title=None,
                  **kwargs):
    """
    Prepare data for 2D raster plot of firing rate across recording

    Spikes are binned with bincount2D and the transposed counts are divided by
    t_bin to obtain a rate.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the
        upper depth limit of the binning
    :param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
    :param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
    :param cmap:
    :param display: generate figure
    :param title: plot title, 'Firing Rate' when not given
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the fig and ax objects from plot_image
    """
    if not title:
        title = 'Firing Rate'

    deepest = np.max(chn_coords[:, 1])
    counts, time_scale, depth_scale = bincount2D(spike_times,
                                                 spike_depths,
                                                 t_bin,
                                                 d_bin,
                                                 ylim=[0, deepest])
    fr = counts.T / t_bin

    data = ImagePlot(fr, x=time_scale, y=depth_scale, cmap=cmap)
    data.set_labels(title=title,
                    xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)',
                    clabel='Firing Rate (Hz)')

    # Colour limits span the range of the rate averaged along axis 0
    mean_fr = np.mean(fr, axis=0)
    data.set_clim(clim=(np.min(mean_fr), np.max(mean_fr)))

    if not display:
        return data

    ax, fig = plot_image(data.convert2dict(), **kwargs)
    return data.convert2dict(), fig, ax
Ejemplo n.º 6
0
def plot_cdf(spike_amps,
             spike_depths,
             spike_times,
             n_amp_bins=10,
             d_bin=40,
             amp_range=None,
             d_range=None,
             display=False,
             cmap='hot',
             ax=None):
    """
    Plot cumulative amplitude of spikes across depth

    For every depth bin, the value at amplitude edge i is the number of spikes
    in that depth bin with amplitude >= edge i, divided by the latest spike
    time. Spikes with amplitude above the last edge are counted in the last
    bin; spikes with amplitude below the first edge are not counted.

    :param spike_amps: amplitude of each spike; multiplied by 1e6 for the x axis
        (presumably V to uV - confirm)
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike; the maximum is used to normalise counts
    :param n_amp_bins: number of amplitude bins to use
    :param d_bin: the value of the depth bins in um (default is 40 um)
    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically
        computed from spike_amps (minimum to 0.9 quantile)
    :param d_range: depth range to use, by default [0, 3840]
    :param display: whether or not to display plot
    :param cmap:
    :param ax: axis passed on to plot_image when display=True
    :return: ImagePlot object, if display=True returns instead the dict from
        ImagePlot.convert2dict together with the fig and ax objects from plot_image
    """

    # Explicit None checks: `x or default` raises on multi-element numpy arrays
    if amp_range is None:
        amp_range = np.quantile(spike_amps, (0, 0.9))
    if d_range is None:
        d_range = [0, 3840]

    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
    t_bin = np.max(spike_times)
    n_edges = len(amp_bins)

    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
    depth_edges = zip(depth_bins[:-1], depth_bins[1:])
    for d, (depth_lo, depth_hi) in enumerate(depth_edges):
        spikes = np.bitwise_and(spike_depths > depth_lo,
                                spike_depths <= depth_hi)
        # digitize returns 0 for amplitudes below the first edge and n_edges
        # for amplitudes at or above the last edge. Slot 0 is dropped so that
        # low amplitudes do not wrap around into the last (highest) bin.
        bin_idx = np.digitize(spike_amps[spikes], amp_bins)
        counts = np.bincount(bin_idx, minlength=n_edges + 1)[1:]
        h = counts / t_bin
        # Reverse cumulative sum: rate of spikes at or above each edge
        cdfs[d, :] = np.cumsum(h[::-1])[::-1]

    # Empty cells are set to NaN rather than 0
    cdfs[cdfs == 0] = np.nan

    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
    data.set_labels(title='Cumulative Amplitude',
                    xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)',
                    clabel='Firing Rate (Hz)')

    if display:
        ax, fig = plot_image(data.convert2dict(),
                             fig_kwargs={'figsize': [3, 7]},
                             ax=ax)
        return data.convert2dict(), fig, ax

    return data