Esempio n. 1
0
def playsta(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Create a looped animation for a single STA with 3 dimensions.

    Parameters
    ----------
    sta:
        STA array with 3 dimensions. The last dimension indexes time and
        each ``sta[:, :, i]`` becomes one frame of the animation.
    frame_duration:
        Time between frames in seconds (optional). When given and no
        ``interval`` keyword is passed, the animation is played at this
        rate, i.e. ``interval = frame_duration * 1000`` ms.
    cmap:
        Colormap to be used. Defaults to the specified colormap in the
        config file.
    centerzero:
        Center the colormap around zero if True.
    **kwargs:
        Passed on to ``matplotlib.animation.ArtistAnimation``, e.g.
        interval: Delay between consecutive frames in ms.
        repeat_delay: Time to wait before the animation is repeated in ms.

    Returns
    -------
    ani:
        The ``ArtistAnimation`` object.

    Note
    ----
    The returned animation can be saved like so:

    >>> ani = playsta(sta)
    >>> ani.save('wheretosave/sta.gif', writer='imagemagick', fps=10)
    """
    check_interactive_backend()

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(sta)
        vmin = asc.absmin(sta)
    else:
        vmax, vmin = sta.max(), sta.min()

    if frame_duration is not None:
        # An explicitly passed `interval` takes precedence over frame_duration
        kwargs.setdefault('interval', frame_duration * 1000)

    fig = plt.figure()
    ax = plt.gca()
    ims = []
    for i in range(sta.shape[-1]):
        im = ax.imshow(sta[:, :, i],
                       animated=True,
                       cmap=cmap,
                       vmin=vmin,
                       vmax=vmax)
        # ArtistAnimation expects one list of artists per frame
        ims.append([im])

    return animation.ArtistAnimation(fig, ims, **kwargs)
Esempio n. 2
0
def stashow(sta, ax=None, cbar=True, **kwargs):
    """
    Plot an STA with a symmetric colormap and an optional colorbar.

    The STA can be a single frame from checkerflicker or a whole STA
    from stripeflicker.

    Accepted kwargs, routed to the matching call:
        imshow
            extent: Change the labels of the axes. [xmin, xmax, ymin, ymax]
            aspect: Aspect ratio of the image. 'auto', 'equal'
            cmap:  Colormap to be used. Default is set in config
        colorbar
            size: Width of the colorbar as percentage of image dimension
                  Default is 2%
            ticks: Where the ticks should be placed on the colorbar.
                   Default is the colormap limits.
            format: Format for the tick labels. Default is '%.2f'

    Any other kwarg raises ValueError.

    Returns the image object created by ``ax.imshow``.

    Usage:
        ax = plt.subplot(111)
        stashow(sta, ax)
    """
    vmax = asc.absmax(sta)
    vmin = asc.absmin(sta)

    # Defaults for each destination; user kwargs override them below
    imshowkw = dict(cmap=iof.config('colormap'), vmin=vmin, vmax=vmax)
    cbarkw = dict(size='2%', ticks=[vmin, vmax], format='%.2f')

    imshow_keys = ('extent', 'aspect', 'cmap')
    colorbar_keys = ('size', 'ticks', 'format')

    for name, value in kwargs.items():
        if name in imshow_keys:
            destination = imshowkw
        elif name in colorbar_keys:
            destination = cbarkw
        else:
            raise ValueError(f'Unknown kwarg: {name}')
        destination[name] = value

    ax = plt.gca() if ax is None else ax

    image = ax.imshow(sta, **imshowkw)
    spineless(ax)
    if cbar:
        colorbar(image, **cbarkw)
    return image
Esempio n. 3
0
                ax.xaxis.set_ticklabels([])
            elif row % nsubrows == 1:
                ax.plot(motionfilt_r[mode_i, :], color='k')
                if col == 0: ax.set_ylabel('Magnitude')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
            elif row % nsubrows == 2:
                ax.plot(motionfilt_theta[mode_i, :], color='r')
                ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
            elif row % nsubrows == nsubrows - 1:
                im = ax.imshow(cells[mode_i, :],
                               cmap='RdBu_r',
                               vmin=asc.absmin(cells),
                               vmax=asc.absmax(cells),
                               aspect='auto',
                               interpolation='nearest')
                ax.set_xticks(np.array([0, .25, .5, .75, 1]) * cells.shape[-1])
                ax.xaxis.set_ticklabels(
                    np.round((ax.get_xticks() * st.frame_duration), 1))
                ax.set_xlabel('Time [s]')
                if col == 0: ax.set_ylabel('Cells')
            if mode_i == n_components - 1:
                plf.colorbar(im)
    # fig.tight_layout()
    fig.subplots_adjust(wspace=0.3)
    fig.savefig(
        f'/Users/ycan/Downloads/2020-05-27_meeting/cca_DScells_shuffled.pdf')
    plt.show()
stc2 = calc_stc(spikes, stimulus, filter_length)

fig, axes = plt.subplots(2, 1)
axes[0].plot(eigvals, 'o')
ax1 = axes[1]
ax1.plot(sta)
ax1.plot(eigvecs[:, 0])
ax1.plot(eigvecs[:, -1])

#%%
# Linear (bmel) and quadratic (cmel) filter estimates from the STA and STC.
sp_mean = spikes.mean()
stim_covar = np.cov(rw, rowvar=False)
# np.outer yields the (n, n) outer product both for a 1-D STA and for an
# (n, 1) column vector; `sta @ sta.T` collapses to a scalar for a 1-D STA,
# which would add the same constant to every covariance entry.
term1 = stc + (2 * np.outer(sta, sta)) / sp_mean
term1_inv = np.linalg.inv(term1)
bmel = term1_inv @ sta
cmel = (np.linalg.inv(stim_covar) - sp_mean * term1_inv) / 2

fig2, axes2 = plt.subplots(2, 2)
axes2 = axes2.flat
axes2[0].plot(bmel)
axes2[0].plot(sta, ls='dashed', color='grey', label='STA')
axes2[1].imshow(cmel,
                cmap='seismic',
                vmin=asc.absmin(cmel),
                vmax=asc.absmax(cmel))

v, w = np.linalg.eigh(cmel)
eiginds = [0, -1]
axes2[2].plot(v, 'ko')
# Highlight the selected eigenvalues at their actual x positions; plotting a
# bare scalar would always place the marker at x=0.
for ind in eiginds:
    axes2[2].plot(np.arange(len(v))[ind], v[ind], 'o')
axes2[3].plot(w[:, eiginds])
Esempio n. 5
0
def cca_omb_components(exp: str,
                       stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the results.

    Parameters
    ----------
    exp:
        Experiment identifier, passed to ``OMB``.
    stim_nr:
        Stimulus number, passed to ``OMB``.
    n_components:
        Number of components that will be requested from the CCA analysis.
        More components mean the algorithm will stop at a later point. That
        means components of analyses with fewer n_components are going to be
        identical to the first n components of the higher-number component
        analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not relevant
        for macke. Defaults to 0.
    filter_length:
        The length of the time window to be considered in the past for the
        stimulus and the responses. Can be different for stimulus and
        response, if a tuple ``(stimulus_length, response_length)`` is given.
        Defaults to the ``filter_length`` of the OMB object.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke` (default).
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid
        memory and performance issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results.
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the
        stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, they
        are saved under ``CCA`` in the stimulus directory of the experiment.
        The directory is created if it does not exist.
    savefig: bool
        Whether to save the figures. The ``.npz`` data file is always saved.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes during
        the stimulus (ascending) in the combined figure.
    select_cells: list
        A list of indexes for the subset of cells to perform the analysis for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Passed on to the rcca solver; not used by the macke solver.

    Raises
    ------
    ValueError
        If ``cca_solver`` is not one of the supported options.
    """
    # Validate before loading any data
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(f"Unknown cca_solver: {cca_solver!r}. "
                         "Options are 'rcca' and 'macke'.")

    if regularization is None:
        regularization = 0

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length

    if isinstance(filter_length, (int, np.integer)):
        # Plain ints keep the repr (used in file names) clean
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    savedir = Path(savedir)
    # np.savez at the end always writes here, so the directory must exist
    # whether or not it was supplied by the caller.
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()
    # Per-cell spike counts have to be taken before the mean subtraction:
    # afterwards every row sums to ~0, which would make both the zero-spike
    # mask and the sorting by number of spikes meaningless.
    nspikes_percell = spikes.sum(axis=1)
    nonzerospikerows = ~np.isclose(nspikes_percell, 0)
    # Set the mean to zero for spikes (not in place, so an integer spike
    # array does not fail and the array returned by allspikes is untouched)
    spikes = spikes - spikes.mean(axis=1)[:, None]

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        # Keep the per-cell bookkeeping aligned with the selected rows
        nspikes_percell = nspikes_percell[select_cells]
        nonzerospikerows = nonzerospikerows[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a list of Python ints for a clean string representation:
        # np.array is printed as "array([....])" with newline characters,
        # which is problematic in filenames.
        select_cells = select_cells.tolist()

    # Exclude rows that have no spikes throughout
    if exclude_allzero_spike_rows:
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file name gets too long due to the list of selected cells, summarize it.
    if len(figsavename) > 200 and select_cells:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke', validated above
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    resp_comps_toplot = resp_comps_sorted_nsp if sort_by_nspikes else resp_comps
    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    resp_vmin, resp_vmax = asc.absmin(resp_comps), asc.absmax(resp_comps)

    # --- Cell weights of each component, in the default cell order ---
    nrows, ncols = plf.numsubplots(n_components)
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)
    for i in range(n_components):
        ax = axes_cells.flat[i]
        ax.imshow(resp_comps[i, :],
                  cmap='RdBu_r',
                  vmin=resp_vmin,
                  vmax=resp_vmax,
                  aspect='auto',
                  interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    # --- Combined figure: motion filters, magnitude, direction, cell weights ---
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component
    nsubrows = len(height_list)

    # Time vectors in ms (negative = before the spike). Built from integer
    # frame indices so their length always equals the filter length; a
    # float-step np.arange can produce an extra sample due to rounding.
    t_stim = -np.arange(filter_length[0])[::-1] * st.frame_duration * 1000
    t_response = -np.arange(filter_length[1])[::-1] * st.frame_duration * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    fig, axes = plt.subplots(nrows=nrows * nsubrows,
                             ncols=ncols,
                             gridspec_kw={'height_ratios': nrows * height_list},
                             figsize=(11, 10),
                             squeeze=False)

    cells_ylabel = ('Cells\n'
                    + ('(sorted nsp)' if sort_by_nspikes else '')
                    + '\n'
                    + (f'(first {plot_first_ncells} cells)'
                       if type(plot_first_ncells) is int else ''))

    for row, ax_row in enumerate(axes):
        subrow = row % nsubrows
        for col, ax in enumerate(ax_row):
            mode_i = (row // nsubrows) * ncols + col
            if mode_i >= n_components:
                # The subplot grid can have more slots than components
                ax.set_axis_off()
                continue
            ax.set_yticks([])
            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim,
                        stim_comps[mode_i, 0, :],
                        marker='o',
                        markersize=1)
                ax.plot(t_stim,
                        stim_comps[mode_i, 1, :],
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of x-axis
                zero_line_xlim = ax.get_xlim()
                ax.hlines(0,
                          *zero_line_xlim,
                          colors='k',
                          linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*zero_line_xlim)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # Only the first component shows stimulus time labels, and
                # only when they differ from the response time axis
                if mode_i != 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim,
                        motionfilt_r[mode_i, :],
                        color='k',
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim,
                        motionfilt_theta[mode_i, :],
                        color='r',
                        marker='o',
                        markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=resp_vmin,
                               vmax=resp_vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(cells_ylabel,
                                  rotation=0,
                                  ha='right',
                                  va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    # --- Canonical correlations ---
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    # --- Nonlinearities between stimulus and response generator signals ---
    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    # Reshape to perform the convolution as a matrix multiplication
    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    # Only the first n_components axes hold data; the rest of the grid is unused
    nlt_axes = axes_nlt.flat[:n_components]
    all_nonlinearities = None
    all_bins = None
    for i, ax in enumerate(nlt_axes):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if all_nonlinearities is None:
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins
    for ax in axes_nlt.flat[n_components:]:
        ax.set_axis_off()

    # Common axis limits for all nonlinearity plots
    nlt_xlims = [lim for ax in nlt_axes for lim in ax.get_xlim()]
    nlt_ylims = [lim for ax in nlt_axes for lim in ax.get_ylim()]
    nlt_minx, nlt_maxx = min(nlt_xlims), max(nlt_xlims)
    nlt_miny, nlt_maxy = min(nlt_ylims), max(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    np.savez(savedir / figsavename, **datadict)
Esempio n. 6
0
import matplotlib.pyplot as plt

# Visualize a fitted CCA model.
# NOTE(review): `cca`, `filter_length`, `n_components`, `stimulus`, `plf`,
# `asc` and `np` must be defined by earlier cells; `cca` looks like a
# scikit-learn CCA instance (x_weights_, y_weights_, coef_) -- confirm.

# Cell weights, transposed so that rows are components (see ylabel below)
cells = cca.x_weights_.T

# The stimulus weights stack the x and y motion filters: the first
# `filter_length` rows belong to x, the remaining rows to y.
motionfilt_x = cca.y_weights_[:filter_length].T
motionfilt_y = cca.y_weights_[filter_length:].T

fig, axes = plt.subplots(*plf.numsubplots(n_components))

# One subplot per component with both motion filters on a shared y-scale.
# NOTE(review): assumes the subplot grid has exactly n_components slots;
# extra slots would index past the last component.
for i, ax in enumerate(axes.flat):
    ax.plot(motionfilt_x[i, :])
    ax.plot(motionfilt_y[i, :])
    ax.set_ylim(cca.y_weights_.min(), cca.y_weights_.max())
fig.suptitle('Stimulus components for each CCA filter')
plt.show()

# Heatmap of cell weights with a colormap symmetric around zero
im = plt.imshow(cells, cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells))
plt.ylabel('Components')
plt.xlabel('Cells')
plf.colorbar(im)
plt.show()



# Scatter of the first two columns of the regression coefficients
plt.plot(cca.coef_[:, 0], cca.coef_[:, 1], 'ko')
plt.show()

# Stimulus weights of all components stacked vertically, each shifted by
# 0.6, with a thin horizontal baseline per component
offsets = np.arange(n_components)*.6
plt.plot(cca.y_weights_ + np.tile(offsets, (stimulus.shape[1], 1)), lw=0.5)
plt.hlines(offsets, 0, stimulus.shape[1], lw=.2)
Esempio n. 7
0
            ax.set_ylim(cca.y_weights_.min(), cca.y_weights_.max())
            ax.set_title(f'Component {mode_i}', fontweight='bold')
            ax.xaxis.set_ticklabels([])
        elif row % nsubrows == 1:
            ax.plot(motionfilt_r[mode_i, :], color='k')
            if col==0: ax.set_ylabel('Magnitude')
            ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
            ax.xaxis.set_ticklabels([])
        elif row % nsubrows == 2:
            ax.plot(motionfilt_theta[mode_i, :], color='r')
            ax.yaxis.set_ticks([-np.pi, 0, np.pi])
            ax.yaxis.set_ticklabels(['-π', 0, 'π'])
            ax.xaxis.set_ticklabels([])
        elif row % nsubrows == nsubrows-1:
            im = ax.imshow(cells[mode_i, :], cmap='RdBu_r',
                           vmin=asc.absmin(cells), vmax=asc.absmax(cells),
                           aspect='auto',
                         interpolation='nearest')
            ax.set_xticks(np.array([0, .25, .5, .75, 1]) * cells.shape[-1])
            ax.xaxis.set_ticklabels(np.round((ax.get_xticks()*st.frame_duration), 1))
            ax.set_xlabel('Time [s]')
            if col==0: ax.set_ylabel('Cells')

# fig.tight_layout()
fig.subplots_adjust(wspace=0.3)
# Encode the selected cell indices in the file name, joined by underscores
selcels = '_'.join(str(cell) for cell in selected_cells)
fig.savefig(
    f'/Users/ycan/Downloads/2020-05-27_meeting/{selcels}_shuffled{shuffle_spikes}.pdf'
)
plt.show()

#%%
Esempio n. 8
0
def stabrowser(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Return an interactive plot to browse a spatiotemporal STA.
    Requires an interactive matplotlib backend.

    Parameters
    ----------
    sta:
        Numpy array containing the STA. Last dimension should index time.
    frame_duration:
        Time between each frame in seconds (optional). If given, the figure
        title shows the time of the displayed frame in ms.
    cmap:
        Colormap to use. Defaults to the colormap in the config file.
    centerzero:
        Whether to center the colormap around zero for diverging colormaps.
    **kwargs:
        Additional keyword arguments passed on to ``ax.imshow``.

    Returns
    -------
    fig, slider_t:
        The figure and the frame slider.

    Example
    -------
    >>> print(sta.shape) # (xpixels, ypixels, time)
    (75, 100, 40)
    >>> fig, slider = stabrowser(sta, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the reference
    to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    check_interactive_backend()

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(sta)
        vmin = asc.absmin(sta)
    else:
        vmax, vmin = sta.max(), sta.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin, **kwargs)

    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

    # Start at frame 5, but never past the last frame of short STAs
    initial_frame = min(5, sta.shape[-1] - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is set to a variable and returned by the function
    slider_t = Slider(axsl,
                      'Frame before spike',
                      0,
                      sta.shape[-1] - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        frame = int(frame)
        im = ax.get_images()[0]
        im.set_data(sta[..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    ax.imshow(sta[..., initial_frame], **imshowkwargs)
    # Show the time label for the initial frame too, not only after sliding
    update(initial_frame)
    ax.set_axis_off()
    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
Esempio n. 9
0
def multistabrowser(stas,
                    frame_duration=None,
                    normalize=True,
                    cmap=None,
                    centerzero=True,
                    **kwargs):
    """
    Return an interactive plot to browse multiple spatiotemporal
    STAs at the same time. Requires an interactive matplotlib backend.

    Parameters
    ----------
    stas:
        Numpy array containing STAs. First dimension should index individual
        cells, last dimension should index time.
        Alternatively, this could be a list of numpy arrays.
    frame_duration:
        Time between each frame in seconds (optional). If given, the figure
        title shows the time of the displayed frame in ms.
    normalize:
        Whether to normalize each STA.
    cmap:
        Colormap to use. Defaults to the colormap in the config file.
    centerzero:
        Whether to center the colormap around zero for diverging colormaps.
    **kwargs:
        Additional keyword arguments passed on to ``ax.imshow``.

    Returns
    -------
    fig, slider_t:
        The figure and the frame slider.

    Example
    -------
    >>> print(stas.shape) # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = multistabrowser(stas, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the reference
    to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    check_interactive_backend()

    if isinstance(stas, list):
        stas = np.array(stas)

    if normalize:
        stas = asc.normalize(stas)

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(stas)
        vmin = asc.absmin(stas)
    else:
        vmax, vmin = stas.max(), stas.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin, **kwargs)

    ncells = stas.shape[0]
    rows, cols = numsubplots(ncells)
    # squeeze=False keeps `axes` 2D even for a single row/column or one cell
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)

    # Start at frame 5, but never past the last frame of short STAs
    initial_frame = min(5, stas.shape[-1] - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is set to a variable and returned by the function
    slider_t = Slider(axsl,
                      'Frame before spike',
                      0,
                      stas.shape[-1] - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        frame = int(frame)
        # Row-major flat order matches the cell index; zip stops after the
        # last cell, so empty grid slots are skipped.
        for ax, cell_sta in zip(axes.flat, stas):
            ax.get_images()[0].set_data(cell_sta[..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    for flat_idx, ax in enumerate(axes.flat):
        if flat_idx < ncells:
            ax.imshow(stas[flat_idx, ..., initial_frame], **imshowkwargs)
        ax.set_axis_off()
    # Show the time label for the initial frame too, not only after sliding
    update(initial_frame)
    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
Esempio n. 10
0
# Keyword arguments are required by seaborn >= 0.12 (positional x/y/data/kind
# raise a TypeError there) and work in older versions as well.
g = sns.jointplot(
    x='x',
    y='y',
    data=sizes,
    kind='scatter',
)

#%%
import plotfuncs as plf

# Deliberately stop here when the whole file is run; the cells below are
# meant to be executed one at a time.
raise SystemExit('Stopped on purpose; run the following cells manually.')

# Show each STA at its selected frame with the fitted ellipse on top
for i, sta in enumerate(stas):
    plt.imshow(sta[..., max_inds[i][-1]],
               cmap='RdBu_r',
               vmax=asc.absmax(sta),
               vmin=asc.absmin(sta))
    drawellipse(all_pars[i])
    plt.title(f'{i}')
    plt.show()

#%%
# Browse all STAs; `sl` keeps a reference to the slider so it stays interactive
fig, sl = plf.multistabrowser(stas)
# Overlay each cell's fitted ellipse on its STA panel.
# NOTE(review): fig.axes[i] is the i-th panel only if the subplot grid axes
# precede the slider axes in fig.axes -- confirm for plf.multistabrowser.
for i in range(len(stas)):
    ax = fig.axes[i]

    drawellipse(all_pars[i], ax)