示例#1
0
from sklearn.cross_decomposition import CCA

import matplotlib.pyplot as plt

from omb import OMB
import analysis_scripts as asc
import plotfuncs as plf
from model_fitting_tools import packdims, shiftspikes, cart2pol

exp, stim_nr = '20180710_kilosorted', 8
n_components = 4

st = OMB(exp, stim_nr)
filter_length = st.filter_length

spikes = st.allspikes()
bgsteps = st.bgsteps

# Restrict the analysis to a hand-picked subset of cells (presumably
# direction-selective ones, given the name `ds_cells`).
ds_cells = [8, 33, 61, 73, 79]
selected_cells = ds_cells

# True only when exactly one cell is analysed.
singlecell = len(selected_cells) == 1

spikes = spikes[selected_cells, :]
st.nclusters = len(selected_cells)

# Never request more components than there are selected cells.
n_components = min(n_components, st.nclusters)

import spikeshuffler
def omb_contrastmotion2dnonlin(exp,
                               stim,
                               nbins_nlt=9,
                               cmap='Greys',
                               plot3d=False):
    """
    Calculate and plot the 2D nonlinearities for the OMB stimulus. The
    magnitude of the stimulus projection on quadratic motion filters
    from GQM is used for the motion.

    For every cell a figure is shown and saved in
    ``<stim_dir>/2D-nonlin_magQ_motion_kcontrast/<cluster id>``. The
    nonlinearities and bin edges of all cells are written to
    ``<stimnr>_2D-nonlin_magQ_motion_kcontrast.npz`` in the same folder.

    Parameters:
    ------
        exp:
            Experiment identifier, passed on to ``OMB``.
        stim:
            Stimulus number, passed on to ``OMB`` and used to locate the
            ``GQM_motioncontrast`` fit results.
        nbins_nlt:
            Number of bins to be used for dividing the generator signals
            into ranges with equal number of samples.
        cmap:
            Colormap of the 2D nonlinearity heatmap.
        plot3d:
            Whether to additionally create a 3D version of the nonlinearity.
    """
    if plot3d:
        # Importing mplot3d registers the '3d' projection on older
        # matplotlib versions.
        from mpl_toolkits import mplot3d  # noqa: F401
        from matplotlib.ticker import MaxNLocator

    st = OMB(exp, stim)

    # Motion and contrast GQM fit. The context manager closes the npz file
    # handle once the needed arrays have been read.
    gqm_path = os.path.join(st.exp_dir, 'data_analysis', st.stimname,
                            'GQM_motioncontrast',
                            f'{stim}_GQM_motioncontrast.npz')
    with np.load(gqm_path) as data_cm:
        qall = data_cm['Qall']
        kall = data_cm['kall']

    allspikes = st.allspikes()

    stim_mot = st.bgsteps.copy()

    # Bin dimension should be one greater than nonlinearity for pcolormesh
    # compatibility. Otherwise the last row and column of nonlinearity is not
    # plotted.
    all_bins_c = np.zeros((st.nclusters, nbins_nlt + 1))
    all_bins_r = np.zeros((st.nclusters, nbins_nlt + 1))
    nonlinearities = np.zeros((st.nclusters, nbins_nlt, nbins_nlt))

    label = '2D-nonlin_magQ_motion_kcontrast'

    savedir = os.path.join(st.stim_dir, label)
    os.makedirs(savedir, exist_ok=True)

    barkwargs = dict(alpha=.3, facecolor='k', linewidth=.5, edgecolor='w')

    for i in range(st.nclusters):
        stim_con = st.contrast_signal_cell(i).squeeze()

        # Project the motion stimulus onto the quadratic filter
        generator_x = gqm.conv2d(qall[i, 0, :], stim_mot[0, :])
        generator_y = gqm.conv2d(qall[i, 1, :], stim_mot[1, :])

        # Calculate the magnitude of the vector formed by motion generators
        generators = np.vstack([generator_x, generator_y])
        r = np.sqrt(np.sum(generators**2, axis=0))

        # Project the contrast stimulus onto the linear filter; the trailing
        # filter_length - 1 samples of the full convolution are dropped.
        generator_c = np.convolve(stim_con, kall[i, 2, :],
                                  'full')[:-st.filter_length + 1]
        spikes = allspikes[i, :]

        nonlinearity, bins_c, bins_r = nlt.calc_nonlin_2d(spikes,
                                                          generator_c,
                                                          r,
                                                          nr_bins=nbins_nlt)
        # Divide by the frame duration to express the rate in spikes/s
        # (the unit shown on the colorbar).
        nonlinearity /= st.frame_duration

        all_bins_c[i, :] = bins_c
        all_bins_r[i, :] = bins_r
        nonlinearities[i, ...] = nonlinearity

        X, Y = np.meshgrid(bins_c, bins_r, indexing='ij')

        fig = plt.figure()

        gs = gsp.GridSpec(5, 5)
        axmain = plt.subplot(gs[1:, :-1])
        axx = plt.subplot(gs[0, :-1], sharex=axmain)
        axy = plt.subplot(gs[1:, -1], sharey=axmain)

        # Normally subplots turns off shared axis tick labels but
        # Gridspec does not do this
        plt.setp(axx.get_xticklabels(), visible=False)
        plt.setp(axy.get_yticklabels(), visible=False)

        im = axmain.pcolormesh(X, Y, nonlinearity, cmap=cmap)
        plf.integerticks(axmain)

        cb = plt.colorbar(im)
        cb.outline.set_linewidth(0)
        cb.ax.set_xlabel('spikes/s')
        cb.ax.xaxis.set_label_position('top')

        plf.integerticks(cb.ax, 4, which='y')
        plf.integerticks(axx, 1, which='y')
        plf.integerticks(axy, 1, which='x')

        # Marginals: mean firing rate along each generator axis
        axx.bar(nlt.bin_midpoints(bins_c),
                nonlinearity.mean(axis=1),
                width=np.ediff1d(bins_c),
                **barkwargs)
        axy.barh(nlt.bin_midpoints(bins_r),
                 nonlinearity.mean(axis=0),
                 height=np.ediff1d(bins_r),
                 **barkwargs)
        plf.spineless(axx, 'b')
        plf.spineless(axy, 'l')

        axmain.set_xlabel('Projection onto linear contrast filter')
        axmain.set_ylabel(
            'Magnitude of projection onto quadratic motion filters')
        # Reuse the spikes loaded above instead of reloading all spike
        # trains for every cell.
        fig.suptitle(
            f'{st.exp_foldername}\n{st.stimname}\n{st.clids[i]} '
            f'2D nonlinearity nsp: {allspikes[i, :].sum():<5.0f}')

        plt.subplots_adjust(top=.85)
        fig.savefig(os.path.join(savedir, st.clids[i]), bbox_inches='tight')
        plt.show()
        # Release the figure; otherwise one figure per cell stays in memory
        # on non-interactive backends.
        plt.close(fig)

        if plot3d:
            # NOTE(review): the 3D figure is neither saved nor closed here;
            # it stays open for interactive inspection.
            fig3d = plt.figure()
            ax = plt.axes(projection='3d')
            ax.plot_surface(X,
                            Y,
                            nonlinearity,
                            cmap='YlGn',
                            edgecolors='k',
                            linewidths=0.2)
            ax.set_xlabel('Projection onto linear contrast filter')
            ax.set_ylabel(
                'Magnitude of projection onto quadratic motion filters')

            ax.set_zlabel(r'Firing rate [sp/s]')
            ax.view_init(elev=30, azim=-135)

            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            ax.zaxis.set_major_locator(MaxNLocator(integer=True))

    npzfpath = os.path.join(savedir, f'{st.stimnr}_{label}.npz')
    np.savez(npzfpath,
             nonlinearities=nonlinearities,
             all_bins_c=all_bins_c,
             all_bins_r=all_bins_r,
             nbins_nlt=nbins_nlt)
示例#3
0
def cca_omb_components(exp: str,
                       stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the results.

    Figures (if ``savefig``) and an npz file with the components,
    nonlinearities and settings are written to ``savedir``.

    Parameters
    ---
    exp:
        Experiment identifier, passed on to ``OMB``.
    stim_nr:
        Stimulus number of the OMB stimulus, passed on to ``OMB``.
    n_components:
        Number of components that will be requested from the CCA analysis. More numbers mean
        the algorithm will stop at a later point. That means components of analyses with fewer
        n_components are going to be identical to the first n components of the higher-number
        component analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not relevant for macke.
        Defaults to 0.
    filter_length:
        The length of the time window to be considered in the past for the stimulus and the responses.
        Can be different for stimulus and response, if a tuple is given. Defaults to
        ``st.filter_length``.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke` (default).
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid memory and performance
        issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results.
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, ``<stim_dir>/CCA``
        is used. The directory is created if it does not exist.
    savefig: bool
        Whether to save the figures.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes during the stimulus.
    select_cells: list
       A list of indexes for the subset of cells to perform the analysis for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Passed on to the rcca solver; not used by the macke solver.

    Raises
    ---
    ValueError
        If ``cca_solver`` is not one of the supported options.
    """
    if regularization is None:
        regularization = 0

    # Fail early, before any data is loaded, on an unsupported solver.
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(f"Unknown cca_solver {cca_solver!r}; "
                         "options are 'rcca' and 'macke'.")

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length

    # A single length applies to both the stimulus and the response window.
    # numpy integers are accepted too and converted to plain ints so that
    # file names show plain numbers.
    if isinstance(filter_length, (int, np.integer)):
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = Path(st.stim_dir) / 'CCA'
    else:
        savedir = Path(savedir)
    # Create custom locations too; the npz file at the end is always written.
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()
    bgsteps = st.bgsteps

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a list of plain ints for a clean string representation:
        # np.array is printed as "array([....])" with newline characters,
        # which is problematic in filenames.
        select_cells = select_cells.tolist()

    # Count spikes BEFORE removing the mean: afterwards every row sums to ~0,
    # which would make both the all-zero check and the sorting meaningless.
    # Both are computed on the (possibly subselected) rows, so the mask
    # always has the same length as `spikes`.
    nspikes_percell = spikes.sum(axis=1)
    nonzerospikerows = ~np.isclose(nspikes_percell, 0)

    # Exclude rows that have no spikes throughout
    if exclude_allzero_spike_rows:
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    # Set the mean to zero for spikes. Not done in place, so the array
    # returned by st.allspikes() is left unmodified.
    spikes = spikes - spikes.mean(axis=1)[:, None]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file name gets too long due to the list of selected cells,
    # summarize it (only possible when cells were selected).
    if len(figsavename) > 200 and select_cells:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke' (validated above)
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    resp_comps_toplot = (resp_comps_sorted_nsp
                         if sort_by_nspikes else resp_comps)

    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # Shared, symmetric color limits so that all components are comparable
    vmin, vmax = asc.absmin(resp_comps), asc.absmax(resp_comps)

    # Cell weights of every component in their original order.
    # squeeze=False keeps `axes` 2D even for a 1x1 or 1xN grid.
    nrows, ncols = plf.numsubplots(n_components)
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)

    for i, ax in enumerate(axes_cells.flat):
        if i >= n_components:
            # n_components does not necessarily fill the grid
            ax.set_visible(False)
            continue
        ax.imshow(resp_comps[i, :],
                  cmap='RdBu_r',
                  vmin=vmin,
                  vmax=vmax,
                  aspect='auto',
                  interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    nsubplots = (nrows, ncols)
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component

    # Time axes in ms, ending at 0 (the most recent frame)
    t_stim = -np.arange(0, filter_length[0] * st.frame_duration,
                        st.frame_duration)[::-1] * 1000
    t_response = -np.arange(0, filter_length[1] * st.frame_duration,
                            st.frame_duration)[::-1] * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    # Each component occupies `nsubrows` stacked axes: motion filters,
    # magnitude, direction and cell weights.
    nsubrows = len(height_list)
    height_ratios = list(height_list) * int(nsubplots[0])
    fig, axes = plt.subplots(nrows=nsubplots[0] * nsubrows,
                             ncols=nsubplots[1],
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10),
                             squeeze=False)

    cells_ylabel = ('Cells\n'
                    + ('(sorted nsp)' if sort_by_nspikes else '') + '\n'
                    + (f'(first {plot_first_ncells} cells)'
                       if plot_first_ncells is not None else ''))

    for row, ax_row in enumerate(axes):
        for col, ax in enumerate(ax_row):
            mode_i = (row // nsubrows) * nsubplots[1] + col
            if mode_i >= n_components:
                # Grid slot without a component
                ax.set_visible(False)
                continue
            subrow = row % nsubrows
            # No component below this one in its column -> bottom panel
            is_bottom = mode_i + nsubplots[1] >= n_components
            ax.set_yticks([])
            if subrow == 0:
                # Plot motion filters (x and y components)
                ax.plot(t_stim,
                        stim_comps[mode_i, 0, :],
                        marker='o',
                        markersize=1)
                ax.plot(t_stim,
                        stim_comps[mode_i, 1, :],
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of x-axis
                xlims = ax.get_xlim()
                ax.hlines(0,
                          *xlims,
                          colors='k',
                          linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*xlims)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # The stimulus time axis only needs labels (on the first
                # component) when it differs from the response time axis.
                if mode_i != 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            elif subrow == 1:
                # Plot magnitude of motion
                ax.plot(t_stim,
                        motionfilt_r[mode_i, :],
                        color='k',
                        marker='o',
                        markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude',
                                  rotation=0,
                                  ha='right',
                                  va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

            elif subrow == 2:
                # Plot direction of motion
                ax.plot(t_stim,
                        motionfilt_theta[mode_i, :],
                        color='r',
                        marker='o',
                        markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

            elif subrow == nsubrows - 1:
                # Plot cell weights
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=vmin,
                               vmax=vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if is_bottom:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(cells_ylabel,
                                  rotation=0,
                                  ha='right',
                                  va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == nsubplots[1] - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    # Flatten the filters so that the convolution with the packed stimulus
    # and responses becomes a matrix multiplication.
    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    nlt_axes = axes_nlt.ravel()
    all_nonlinearities = None
    all_bins = None
    for i, ax in enumerate(nlt_axes):
        if i >= n_components:
            ax.set_visible(False)
            continue
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if i == 0:
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    used_nlt_axes = nlt_axes[:n_components]

    # Common axis limits across all nonlinearity panels
    nlt_xlims = [lim for ax in used_nlt_axes for lim in ax.get_xlim()]
    nlt_ylims = [lim for ax in used_nlt_axes for lim in ax.get_ylim()]
    nlt_maxx, nlt_minx = max(nlt_xlims), min(nlt_xlims)
    nlt_maxy, nlt_miny = max(nlt_ylims), min(nlt_ylims)

    for idx, ax in enumerate(used_nlt_axes):
        # Axes are ravelled row-major, so divmod recovers (row, column)
        _, j = divmod(idx, ncols)
        if idx + ncols >= n_components:  # bottom panel of its column
            ax.set_xlabel('Generator (motion)')
        if j == 0:
            ax.set_ylabel('Generator (cells)')
        else:
            ax.yaxis.set_ticklabels([])
        ax.set_xlim([nlt_minx, nlt_maxx])
        ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    np.savez(savedir / figsavename,
             n_components=n_components,
             resp_comps=resp_comps,
             stim_comps=stim_comps,
             motionfilt_r=motionfilt_r,
             motionfilt_theta=motionfilt_theta,
             resp_comps_sorted_nsp=resp_comps_sorted_nsp,
             select_cells=select_cells,
             regularization=regularization,
             filter_length=filter_length,
             all_nonlinearities=all_nonlinearities,
             all_bins=all_bins,
             cca_solver=cca_solver)
# Load the saved GQM fits, in the order of gqmlabels[0..2]:
# contrast only, motion only, and motion + contrast.
data_c, data_m, data_cm = [
    np.load(
        os.path.join(st.exp_dir, 'data_analysis', st.stimname, gqm_label,
                     f'{stim}_{gqm_label}.npz'))
    for gqm_label in (gqmlabels[0], gqmlabels[1], gqmlabels[2])
]

# Exclude cells whose mean firing rate is below the cutoff
cutoff = 0.1  # In units of spikes/s
lowq = (st.allspikes().mean(axis=1) / st.frame_duration) < cutoff

cc_c, cc_m, cc_cm = (gqm_data['cross_corrs'][~lowq]
                     for gqm_data in (data_c, data_m, data_cm))

#%% Scatter
fig, axes = plt.subplots(2, 2, figsize=(5.5, 5))

ax1, ax2, ax3, ax4 = axes.flat
    return calculate_loglikelihood(kmu, spikes, stimulus,
                                   time_res) - calculate_ll0(spikes)


if __name__ == '__main__':
    # Script entry point: loads the cross-validated GLM fits of one OMB
    # recording for three model variants (motion + contrast, contrast only,
    # motion only). `cutoff`, `logls` and `model_input` are not used within
    # the lines shown here.
    import matplotlib.pyplot as plt
    from omb import OMB
    import genlinmod_multidimensional as glmm
    import plotfuncs as plf
    #    from driftinggratings import DriftingGratings

    exp, stim = '20180710', 8
    #    exp, stim = 'Kuehn', 13
    st = OMB(exp, stim)
    species = st.metadata["animal"]
    allspikes = st.allspikes()

    # Cross-validated GLM results, stored as <stim_dir>/<label>/<stimnr>_<label>.npz
    data_cm = np.load(
        f'{st.stim_dir}/GLM_motioncontrast_xval/{st.stimnr}_GLM_motioncontrast_xval.npz'
    )
    data_c = np.load(
        f'{st.stim_dir}/GLM_contrast_xval/{st.stimnr}_GLM_contrast_xval.npz')
    data_m = np.load(
        f'{st.stim_dir}/GLM_motion_xval/{st.stimnr}_GLM_motion_xval.npz')

    # (display name, integer code) per model.
    # NOTE(review): the meaning of the integer codes is not visible here —
    # presumably they select the model inputs; confirm.
    model_input = [('Contrast and motion', 3), ('Contrast', 1), ('Motion', 2)]

    # Presumably one log-likelihood per cell (rows) and per entry of
    # model_input (columns) — TODO confirm against the code that fills it.
    logls = np.zeros((st.nclusters, 3))

    # Exclude those with very few spikes
    cutoff = 0.2  # In units of spikes/s
示例#6
0
import matplotlib.pyplot as plt

from omb import OMB
import analysis_scripts as asc
import plotfuncs as plf

from model_fitting_tools import packdims, shiftspikes, cart2pol

# NOTE(review): the other scripts here use '20180710_kilosorted'; the '*'
# may be a glob pattern or a typo — confirm.
exp, stim_nr = '20180710*kilosorted', 8
n_components = 6

st = OMB(exp, stim_nr)
filter_length = st.filter_length

# `spikes` is reassigned inside the loop below before it is used.
spikes = st.allspikes()
bgsteps = st.bgsteps

# Loop over shifts of the spike trains (presumably relative to the
# stimulus; only a shift of 0 is used at the moment).
for shift in [0]:
    print(shift)

    spikes = shiftspikes(st.allspikes(), shift)

    # Pack the last `filter_length` frames of stimulus and spikes
    # (packdims semantics assumed from its use here — confirm).
    stimulus = packdims(st.bgsteps, filter_length)
    spikes = packdims(spikes, filter_length)

    # NOTE(review): the CCA model is constructed but not fitted in the
    # lines shown here.
    cca = CCA(n_components=n_components,
              scale=True,
              max_iter=500,
              tol=1e-06,
              copy=True)
def omb_contrastmotion2dnonlin_Qcomps(exp, stim, nbins_nlt=9, cmap='Greys'):
    """
    Calculate and plot the 2D nonlinearities for the OMB stimulus. Multiple
    components of the matrix Q for the motion.

    For every cell one figure with a 2 x 3 grid of 2D nonlinearities is
    shown and saved under ``<stim_dir>/2D-nonlin_Qallcomps_motion_kcontrast``.
    In each panel the x axis is the projection of the contrast stimulus onto
    the linear contrast filter. The y axis is the projection of the motion
    stimulus onto an eigenvector of Q. Rows are the activating and
    suppressing eigenvectors. Columns are the X component, the Y component
    and the magnitude of both.

    Parameters:
    ------
        exp:
            Experiment identifier, passed on to ``OMB``.
        stim:
            Stimulus number, passed on to ``OMB`` and used to locate the
            ``GQM_motioncontrast_val`` fit results.
        nbins_nlt:
            Number of bins to be used for dividing the generator signals
            into ranges with equal number of samples.
        cmap:
            Colormap of the 2D nonlinearity heatmaps.
    """

    st = OMB(exp, stim)

    # Motion and contrast GQM fit. The context manager closes the npz file
    # handle once the needed arrays have been read.
    gqm_path = os.path.join(st.exp_dir, 'data_analysis', st.stimname,
                            'GQM_motioncontrast_val',
                            f'{stim}_GQM_motioncontrast_val.npz')
    with np.load(gqm_path) as data_cm:
        kall = data_cm['kall']
        eigvecs = data_cm['eigvecs']

    # Indices into the last axis of eigvecs: activating, suppressing.
    # NOTE(review): assumes the eigenvectors are sorted by ascending
    # eigenvalue — confirm against the GQM fitting code.
    eiginds = [-1, 0]

    allspikes = st.allspikes()

    stim_mot = st.bgsteps.copy()

    # Bin dimension should be one greater than nonlinearity for pcolormesh
    # compatibility. Otherwise the last row and column of nonlinearity is not
    # plotted.
    all_bins_c = np.zeros((st.nclusters, nbins_nlt + 1))
    all_bins_r = np.zeros((st.nclusters, nbins_nlt + 1))
    nonlinearities = np.zeros((st.nclusters, nbins_nlt, nbins_nlt))

    label = '2D-nonlin_Qallcomps_motion_kcontrast'

    row_labels = ['Activating', 'Suppressive']
    column_labels = ['X', 'Y', r'$\sqrt{X^2 + Y^2}$']

    savedir = os.path.join(st.stim_dir, label)
    os.makedirs(savedir, exist_ok=True)

    n = len(column_labels)  # x, y, xy
    m = len(eiginds)  # activating, suppressing

    barkwargs = dict(alpha=.3, facecolor='k', linewidth=.5, edgecolor='w')

    for i in range(st.nclusters):
        stim_con = st.contrast_signal_cell(i).squeeze()
        spikes = allspikes[i, :]

        # Project the contrast stimulus onto the linear filter. This does
        # not depend on the Q component, so it is computed once per cell.
        generator_c = np.convolve(stim_con,
                                  kall[i, 2, :],
                                  'full')[:-st.filter_length + 1]

        fig = plt.figure(figsize=(n * 5, m * 5), constrained_layout=True)
        gs = fig.add_gridspec(m, n)
        # Per panel (row-major): [top marginal, heatmap, right marginal, colorbar]
        axes = []
        for eachgs in gs:
            subgs = eachgs.subgridspec(2, 3, width_ratios=[4, 1, .2],
                                       height_ratios=[1, 4])
            mainax = fig.add_subplot(subgs[1, 0])
            axx = fig.add_subplot(subgs[0, 0], sharex=mainax)
            axy = fig.add_subplot(subgs[1, 1], sharey=mainax)
            cbax = fig.add_subplot(subgs[1, 2])
            axes.append([axx, mainax, axy, cbax])

        for k, eigind in enumerate(eiginds):
            # Project the motion stimulus onto the selected eigenvector of Q
            generator_x = np.convolve(eigvecs[i, 0, :, eigind],
                                      stim_mot[0, :],
                                      'full')[:-st.filter_length + 1]
            generator_y = np.convolve(eigvecs[i, 1, :, eigind],
                                      stim_mot[1, :],
                                      'full')[:-st.filter_length + 1]

            # Calculate the magnitude of the vector formed by motion generators
            generators = np.vstack([generator_x, generator_y])
            generator_xy = np.sqrt(np.sum(generators**2, axis=0))

            generators_motion = [generator_x, generator_y, generator_xy]

            for j, direction in enumerate(column_labels):
                nonlinearity, bins_c, bins_r = nlt.calc_nonlin_2d(
                    spikes,
                    generator_c,
                    generators_motion[j],
                    nr_bins=nbins_nlt)
                # Divide by the frame duration to express the rate in
                # spikes/s (the unit shown on the colorbar).
                nonlinearity /= st.frame_duration

                # NOTE(review): these arrays have no axis for the Q component
                # or the motion direction. Each panel overwrites the previous
                # one, so only the last panel (suppressing, magnitude) of
                # every cell ends up in the saved npz file.
                all_bins_c[i, :] = bins_c
                all_bins_r[i, :] = bins_r
                nonlinearities[i, ...] = nonlinearity

                X, Y = np.meshgrid(bins_c, bins_r, indexing='ij')

                axx, axmain, axy, cbax = axes[k * n + j]

                # Normally subplots turns off shared axis tick labels but
                # Gridspec does not do this
                plt.setp(axx.get_xticklabels(), visible=False)
                plt.setp(axy.get_yticklabels(), visible=False)

                im = axmain.pcolormesh(X, Y, nonlinearity, cmap=cmap)
                plf.integerticks(axmain, 6, which='xy')

                cb = plt.colorbar(im, cax=cbax)
                cb.outline.set_linewidth(0)
                cb.ax.set_xlabel('spikes/s')
                cb.ax.xaxis.set_label_position('top')

                plf.integerticks(cb.ax, 4, which='y')
                plf.integerticks(axx, 1, which='y')
                plf.integerticks(axy, 1, which='x')

                # Marginals: mean firing rate along each generator axis
                axx.bar(nlt.bin_midpoints(bins_c), nonlinearity.mean(axis=1),
                        width=np.ediff1d(bins_c), **barkwargs)
                axy.barh(nlt.bin_midpoints(bins_r), nonlinearity.mean(axis=0),
                         height=np.ediff1d(bins_r), **barkwargs)
                plf.spineless(axx, 'b')
                plf.spineless(axy, 'l')

                if k == 0 and j == 0:
                    axmain.set_xlabel('Projection onto linear contrast filter')
                    axmain.set_ylabel('Projection onto Q component')
                if k == 0:
                    axx.set_title(direction)
                if j == 0:
                    axmain.text(-.3, .5, row_labels[k],
                                va='center',
                                rotation=90,
                                transform=axmain.transAxes)

        # Reuse the spikes loaded above instead of reloading all spike
        # trains for every cell.
        fig.suptitle(f'{st.exp_foldername}\n{st.stimname}\n{st.clids[i]} '
                     f'2D nonlinearity nsp: {allspikes[i, :].sum():<5.0f}')

        # No plt.subplots_adjust here: the figure uses constrained_layout,
        # which already leaves room for the suptitle. matplotlib reports
        # subplots_adjust as incompatible with it.
        fig.savefig(os.path.join(savedir, st.clids[i]), bbox_inches='tight')
        plt.show()
        # Release the figure; otherwise one figure per cell stays in memory
        # on non-interactive backends.
        plt.close(fig)

    npzfpath = os.path.join(savedir, f'{st.stimnr}_{label}.npz')
    np.savez(npzfpath,
             nonlinearities=nonlinearities,
             all_bins_c=all_bins_c,
             all_bins_r=all_bins_r,
             nbins_nlt=nbins_nlt)
    >>> contrast_avg = contrast.mean(axis=-1)
    >>> sta_corrected = subtract_avgcontrast(stas, contrast_avg)
    """
    return sta - contrast_avg[None, :, :, None]


exp, ombstimnr = '20180710_kilosorted', 8
checkerstimnr = 6

st = OMB(exp, ombstimnr, maxframes=None)

choosecells = [8, 33, 61, 73, 79]

nrcells = len(choosecells)

all_spikes = st.allspikes()[choosecells, :]

# Motion STAs of the chosen cells, divided by each cell's total spike count.
# The two np.newaxis insertions imply motionstas is 3-D (cells first).
motionstas = np.array(st.read_datafile()['stas'])[choosecells, :]
motionstas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis]

#%% Filter the stimuli

# Euclidean norm: scale every filter (last axis) to unit length
motionstas_norm = motionstas / np.sqrt(
    (motionstas**2).sum(axis=-1))[:, :, None]

# Scale the stimulus to unit variance (variance over all entries of
# bgsteps) and build the rolling window from the scaled stimulus only.
bgsteps = st.bgsteps / np.sqrt(st.bgsteps.var())
rw = asc.rolling_window(bgsteps, st.filter_length)