from sklearn.cross_decomposition import CCA
import matplotlib.pyplot as plt

from omb import OMB
import analysis_scripts as asc
import plotfuncs as plf
from model_fitting_tools import packdims, shiftspikes, cart2pol

# Experiment folder and OMB stimulus number to analyze.
exp, stim_nr = '20180710_kilosorted', 8
n_components = 4

st = OMB(exp, stim_nr)
filter_length = st.filter_length
spikes = st.allspikes()
bgsteps = st.bgsteps

# Restrict the analysis to a hand-picked subset of cells
# (row indices into the spike array).
ds_cells = [8, 33, 61, 73, 79]
selected_cells = ds_cells
singlecell = len(selected_cells) == 1

spikes = spikes[selected_cells, :]
st.nclusters = len(selected_cells)

# Never request more components than there are cells.
n_components = min(n_components, st.nclusters)

import spikeshuffler
def omb_contrastmotion2dnonlin(exp, stim, nbins_nlt=9, cmap='Greys',
                               plot3d=False):
    """
    Calculate and plot the 2D nonlinearities for the OMB stimulus.

    The magnitude of the stimulus projection on quadratic motion filters
    from GQM is used for the motion. The contrast axis is the projection
    of each cell's contrast signal onto its linear contrast filter
    (`kall[i, 2, :]`).

    One figure per cell is saved into
    `<stim_dir>/2D-nonlin_magQ_motion_kcontrast/`, and the nonlinearities
    and bin edges of all cells are saved to an .npz file in that folder.

    Parameters:
    ------
        nbins_nlt:
            Number of bins to be used for dividing the generator signals
            into ranges with equal number of samples.
        cmap:
            Colormap used for the 2D nonlinearity.
        plot3d:
            Whether to additionally create a 3D version of the
            nonlinearity. It is saved next to the 2D figure with a '_3d'
            suffix.
    """
    st = OMB(exp, stim)

    # Motion and contrast
    data_cm = np.load(
        os.path.join(st.exp_dir, 'data_analysis', st.stimname,
                     'GQM_motioncontrast', f'{stim}_GQM_motioncontrast.npz'))
    qall = data_cm['Qall']
    kall = data_cm['kall']

    allspikes = st.allspikes()
    stim_mot = st.bgsteps.copy()

    # Bin dimension should be one greater than nonlinearity for pcolormesh
    # compatibility. Otherwise the last row and column of nonlinearity is not
    # plotted.
    all_bins_c = np.zeros((st.nclusters, nbins_nlt + 1))
    all_bins_r = np.zeros((st.nclusters, nbins_nlt + 1))
    nonlinearities = np.zeros((st.nclusters, nbins_nlt, nbins_nlt))

    label = '2D-nonlin_magQ_motion_kcontrast'
    savedir = os.path.join(st.stim_dir, label)
    os.makedirs(savedir, exist_ok=True)

    if plot3d:
        # Importing mplot3d registers the '3d' projection on older
        # matplotlib versions.
        from mpl_toolkits import mplot3d  # noqa: F401
        from matplotlib.ticker import MaxNLocator

    for i in range(st.nclusters):
        stim_con = st.contrast_signal_cell(i).squeeze()
        # Project the motion stimulus onto the quadratic filter
        generator_x = gqm.conv2d(qall[i, 0, :], stim_mot[0, :])
        generator_y = gqm.conv2d(qall[i, 1, :], stim_mot[1, :])
        # Calculate the magnitude of the vector formed by motion generators
        generators = np.vstack([generator_x, generator_y])
        r = np.sqrt(np.sum(generators**2, axis=0))
        # Project the contrast stimulus onto the linear filter; the tail of
        # the 'full' convolution is cut off so its length matches the input.
        generator_c = np.convolve(stim_con, kall[i, 2, :],
                                  'full')[:-st.filter_length + 1]
        spikes = allspikes[i, :]

        nonlinearity, bins_c, bins_r = nlt.calc_nonlin_2d(spikes,
                                                          generator_c,
                                                          r,
                                                          nr_bins=nbins_nlt)
        # Per-frame values divided by the frame duration; plotted as spikes/s
        nonlinearity /= st.frame_duration

        all_bins_c[i, :] = bins_c
        all_bins_r[i, :] = bins_r
        nonlinearities[i, ...] = nonlinearity

        X, Y = np.meshgrid(bins_c, bins_r, indexing='ij')

        fig = plt.figure()
        gs = gsp.GridSpec(5, 5)
        axmain = plt.subplot(gs[1:, :-1])
        axx = plt.subplot(gs[0, :-1], sharex=axmain)
        axy = plt.subplot(gs[1:, -1], sharey=axmain)
        # Normally subplots turns off shared axis tick labels but
        # Gridspec does not do this
        plt.setp(axx.get_xticklabels(), visible=False)
        plt.setp(axy.get_yticklabels(), visible=False)

        im = axmain.pcolormesh(X, Y, nonlinearity, cmap=cmap)
        plf.integerticks(axmain)

        cb = plt.colorbar(im)
        cb.outline.set_linewidth(0)
        cb.ax.set_xlabel('spikes/s')
        cb.ax.xaxis.set_label_position('top')

        plf.integerticks(cb.ax, 4, which='y')
        plf.integerticks(axx, 1, which='y')
        plf.integerticks(axy, 1, which='x')

        # Marginals: nonlinearity averaged over the other axis
        barkwargs = dict(alpha=.3, facecolor='k',
                         linewidth=.5, edgecolor='w')
        axx.bar(nlt.bin_midpoints(bins_c), nonlinearity.mean(axis=1),
                width=np.ediff1d(bins_c), **barkwargs)
        axy.barh(nlt.bin_midpoints(bins_r), nonlinearity.mean(axis=0),
                 height=np.ediff1d(bins_r), **barkwargs)
        plf.spineless(axx, 'b')
        plf.spineless(axy, 'l')

        axmain.set_xlabel('Projection onto linear contrast filter')
        axmain.set_ylabel(
            'Magnitude of projection onto quadratic motion filters')
        # `spikes` is this cell's row of allspikes, so no reload is needed
        fig.suptitle(
            f'{st.exp_foldername}\n{st.stimname}\n{st.clids[i]} '
            f'2D nonlinearity nsp: {spikes.sum():<5.0f}')
        plt.subplots_adjust(top=.85)
        fig.savefig(os.path.join(savedir, st.clids[i]), bbox_inches='tight')
        plt.show()

        if plot3d:
            fig3d = plt.figure()
            ax = plt.axes(projection='3d')
            ax.plot_surface(X, Y, nonlinearity, cmap='YlGn',
                            edgecolors='k', linewidths=0.2)
            ax.set_xlabel('Projection onto linear contrast filter')
            ax.set_ylabel(
                'Magnitude of projection onto quadratic motion filters')
            ax.set_zlabel(r'Firing rate [sp/s]')
            ax.view_init(elev=30, azim=-135)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            ax.zaxis.set_major_locator(MaxNLocator(integer=True))
            # Previously the 3D figure was never saved, and it was only
            # shown by the next cell's plt.show() (never for the last cell).
            fig3d.savefig(os.path.join(savedir, f'{st.clids[i]}_3d'),
                          bbox_inches='tight')
            plt.show()

    datadict = {
        'nonlinearities': nonlinearities,
        'all_bins_c': all_bins_c,
        'all_bins_r': all_bins_r,
        'nbins_nlt': nbins_nlt,
    }
    npzfpath = os.path.join(savedir, f'{st.stimnr}_{label}.npz')
    np.savez(npzfpath, **datadict)
def cca_omb_components(exp: str, stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the
    results.

    Parameters
    ---
    n_components:
        Number of components that will be requested from the CCA analysis.
        More components mean the algorithm will stop at a later point.
        That means components of analyses with fewer n_components are going
        to be identical to the first n components of the higher-number
        component analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not
        relevant for macke.
    filter_length:
        The length of the time window to be considered in the past for the
        stimulus and the responses. Can be different for stimulus and
        response, if a tuple (stimulus, response) is given. Defaults to the
        filter length of the stimulus object.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke` (default).
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid
        memory and performance issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results.
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the
        stimulus (after applying `select_cells`).
    savedir: str
        Custom directory to save the figures and data files. If None, will
        be saved in the experiment directory under the appropriate path.
    savefig: bool
        Whether to save the figures.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes
        during the stimulus.
    select_cells: list
        A list of indexes for the subset of cells to perform the analysis
        for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Passed on to the `rcca` solver; not used by `macke`.

    Raises
    ---
    ValueError
        If `cca_solver` is not one of the supported options.
    """
    if regularization is None:
        regularization = 0
    # Fail early instead of with a NameError after all the data is loaded
    if cca_solver not in ('rcca', 'macke'):
        raise ValueError(f"Unknown cca_solver {cca_solver!r}; "
                         "options are 'rcca' and 'macke'.")

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length
    # A single integer is used for both the stimulus and the response
    if isinstance(filter_length, (int, np.integer)):
        filter_length = (int(filter_length), int(filter_length))

    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    savedir = Path(savedir)
    savedir.mkdir(exist_ok=True, parents=True)

    # ---- Spike preprocessing ----------------------------------------------
    spikes = st.allspikes()

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a plain list of Python ints for a better string
        # representation in file names: np.array is printed as
        # "array([....])" with newline characters, and recent numpy versions
        # print scalars as "np.int64(...)".
        select_cells = select_cells.tolist()

    # Count spikes before centering the rows; after centering every row sums
    # to (numerically) zero and sorting by it would be meaningless.
    nspikes_percell = spikes.sum(axis=1)

    # Exclude rows that have no spikes throughout. The mask is computed on
    # the already-selected cells so that its length matches `spikes`.
    if exclude_allzero_spike_rows:
        nonzerospikerows = ~np.isclose(nspikes_percell, 0)
        spikes = spikes[nonzerospikerows, :]
        nspikes_percell = nspikes_percell[nonzerospikerows]

    # Set the mean to zero for spikes. A new array is created so that the
    # array returned by st.allspikes() is not modified in place.
    spikes = spikes - spikes.mean(axis=1)[:, None]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = (f'{n_components=}_{shufflespikes=}_{select_cells=}_'
                   f'{regularization=}_{filter_length=}_{whiten=}')
    # If the file name gets too long due to the list of selected cells,
    # summarize it.
    if len(figsavename) > 200 and select_cells is not None:
        figsavename = (f'{n_components=}_{shufflespikes=}_'
                       f'select_cells={len(select_cells)}cells-index'
                       f'{select_cells[0]}to{select_cells[-1]}_'
                       f'{regularization=}_{filter_length=}_{whiten=}')

    # ---- CCA ----------------------------------------------------------------
    # packdims presumably stacks the filter_length most recent frames of each
    # input dimension into columns (the flattened stimulus filters below have
    # 2 * filter_length[0] entries).
    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization,
                                                    whiten)
    else:  # 'macke'
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    resp_comps_toplot = (resp_comps_sorted_nsp if sort_by_nspikes
                         else resp_comps)
    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # ---- Cell weights in default order -------------------------------------
    nrows, ncols = plf.numsubplots(n_components)
    # squeeze=False keeps the axes array 2D even for a single row/column
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)
    for i in range(n_components):
        ax = axes_cells.flat[i]
        ax.imshow(resp_comps[i, :], cmap='RdBu_r',
                  vmin=asc.absmin(resp_comps),
                  vmax=asc.absmax(resp_comps),
                  aspect='auto', interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    # ---- Components overview -----------------------------------------------
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component
    nsubrows = len(height_list)
    height_ratios = height_list * int(nrows)

    # Time vectors (ms, ending at 0) for the stimulus and response filters.
    # An integer arange guarantees exactly filter_length points; a floating
    # point step can yield one extra element.
    t_stim = -np.arange(filter_length[0])[::-1] * st.frame_duration * 1000
    t_response = -np.arange(filter_length[1])[::-1] * st.frame_duration * 1000

    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    fig, axes = plt.subplots(nrows=nrows * nsubrows, ncols=ncols,
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10), squeeze=False)

    for row, ax_row in enumerate(axes):
        for col, ax in enumerate(ax_row):
            mode_i = (row // nsubrows) * ncols + col
            subrow = row % nsubrows
            ax.set_yticks([])
            # The subplot grid can have more slots than components
            if mode_i >= n_components:
                ax.axis('off')
                continue
            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim, stim_comps[mode_i, 0, :],
                        marker='o', markersize=1)
                ax.plot(t_stim, stim_comps[mode_i, 1, :],
                        marker='o', markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion', rotation=0,
                                  ha='right', va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of
                # x-axis
                xlims = ax.get_xlim()
                ax.hlines(0, *xlims, colors='k', linestyles='dashed',
                          alpha=0.3)
                ax.set_xlim(*xlims)
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if mode_i != 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)
            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim, motionfilt_r[mode_i, :], color='k',
                        marker='o', markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude', rotation=0,
                                  ha='right', va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim, motionfilt_theta[mode_i, :], color='r',
                        marker='o', markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :], cmap='RdBu_r',
                               vmin=asc.absmin(resp_comps),
                               vmax=asc.absmax(resp_comps),
                               aspect='auto',
                               interpolation='nearest',
                               extent=[t_response[0], t_response[-1],
                                       0, resp_comps_toplot.shape[1]])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    sorted_note = '(sorted nsp)' if sort_by_nspikes else ''
                    first_note = (f'(first {plot_first_ncells} cells)'
                                  if isinstance(plot_first_ncells, int)
                                  else '')
                    ax.set_ylabel(f'Cells\n{sorted_note}\n{first_note}',
                                  rotation=0, ha='right', va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)
            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')

    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    # ---- Canonical correlations --------------------------------------------
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Canonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    # ---- Nonlinearities between the stimulus and response generators ------
    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    # Reshape to perform the convolution as a matrix multiplication
    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T
    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    flat_axes_nlt = axes_nlt.ravel()
    used_axes_nlt = flat_axes_nlt[:n_components]
    # Hide unused grid slots; they must not take part in the common limits
    for ax in flat_axes_nlt[n_components:]:
        ax.axis('off')

    for i, ax in enumerate(used_axes_nlt):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')
        if i == 0:
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    # Common axis limits across all components
    nlt_xlims = [lim for ax in used_axes_nlt for lim in ax.get_xlim()]
    nlt_ylims = [lim for ax in used_axes_nlt for lim in ax.get_ylim()]
    nlt_minx, nlt_maxx = min(nlt_xlims), max(nlt_xlims)
    nlt_miny, nlt_maxy = min(nlt_ylims), max(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    # np.savez appends the .npz extension
    np.savez(savedir / figsavename, **datadict)
# Only contrast data_c = np.load( os.path.join(st.exp_dir, 'data_analysis', st.stimname, gqmlabels[0], f'{stim}_{gqmlabels[0]}.npz')) # only Motion data_m = np.load( os.path.join(st.exp_dir, 'data_analysis', st.stimname, gqmlabels[1], f'{stim}_{gqmlabels[1]}.npz')) # Motion and contrast data_cm = np.load( os.path.join(st.exp_dir, 'data_analysis', st.stimname, gqmlabels[2], f'{stim}_{gqmlabels[2]}.npz')) # Exclude those with very few spikes cutoff = 0.1 # In units of spikes/s lowq = (st.allspikes().mean(axis=1) / st.frame_duration) < cutoff cc_c = data_c['cross_corrs'][~lowq] cc_m = data_m['cross_corrs'][~lowq] cc_cm = data_cm['cross_corrs'][~lowq] #%% Scatter fig, axes = plt.subplots( 2, 2, figsize=(5.5, 5), # sharex=True, sharey=True ) ax1, ax2, ax3, ax4 = axes.flat
# NOTE(review): the return below ends a function whose start is outside the
# visible code. It returns calculate_loglikelihood(...) minus
# calculate_ll0(spikes), presumably the log-likelihood relative to a
# baseline/null model -- confirm.
    return calculate_loglikelihood(kmu, spikes, stimulus, time_res) - calculate_ll0(spikes)


if __name__ == '__main__':
    # Load the cross-validated GLM fits for three model inputs
    # (contrast and motion, contrast only, motion only) of one OMB stimulus.
    # NOTE(review): this block appears truncated in the visible code.
    import matplotlib.pyplot as plt
    from omb import OMB
    import genlinmod_multidimensional as glmm
    import plotfuncs as plf
    # from driftinggratings import DriftingGratings

    exp, stim = '20180710', 8
    # exp, stim = 'Kuehn', 13

    st = OMB(exp, stim)
    species = st.metadata["animal"]
    allspikes = st.allspikes()

    data_cm = np.load(
        f'{st.stim_dir}/GLM_motioncontrast_xval/{st.stimnr}_GLM_motioncontrast_xval.npz'
    )
    data_c = np.load(
        f'{st.stim_dir}/GLM_contrast_xval/{st.stimnr}_GLM_contrast_xval.npz')
    data_m = np.load(
        f'{st.stim_dir}/GLM_motion_xval/{st.stimnr}_GLM_motion_xval.npz')

    # (label, integer code) per model; the meaning of the integer is not
    # visible here -- presumably an input-selection flag, confirm.
    model_input = [('Contrast and motion', 3), ('Contrast', 1), ('Motion', 2)]

    # One column per model in model_input
    logls = np.zeros((st.nclusters, 3))

    # Exclude those with very few spikes
    cutoff = 0.2  # In units of spikes/s
import matplotlib.pyplot as plt
from omb import OMB
import analysis_scripts as asc
import plotfuncs as plf
from model_fitting_tools import packdims, shiftspikes, cart2pol

# NOTE(review): other snippets use '20180710_kilosorted'; the '*' here is
# either a wildcard resolved by OMB or a typo -- confirm.
exp, stim_nr = '20180710*kilosorted', 8
n_components = 6

st = OMB(exp, stim_nr)
filter_length = st.filter_length
spikes = st.allspikes()
bgsteps = st.bgsteps

# Repeat the analysis for each temporal shift of the spikes
# (currently only a shift of 0).
for shift in [0]:
    print(shift)
    spikes = shiftspikes(st.allspikes(), shift)
    # Pack filter_length frames of history for stimulus and spikes
    # (presumed semantics of packdims -- confirm)
    stimulus = packdims(st.bgsteps, filter_length)
    spikes = packdims(spikes, filter_length)
    # NOTE(review): CCA is not imported in this snippet (presumably
    # sklearn.cross_decomposition.CCA), and the model is not fitted within
    # the visible code -- the loop body looks truncated.
    cca = CCA(n_components=n_components,
              scale=True, max_iter=500, tol=1e-06, copy=True)
def omb_contrastmotion2dnonlin_Qcomps(exp, stim, nbins_nlt=9, cmap='Greys'):
    """
    Calculate and plot the 2D nonlinearities for the OMB stimulus.

    Multiple components of the matrix Q are used for the motion: the
    eigenvectors selected by `eiginds` (activating and suppressing). For each
    of them the x and y motion stimuli are projected onto the eigenvector,
    and a 2D nonlinearity against the contrast projection is computed for
    the x projection, the y projection and the magnitude of the (x, y)
    projection vector.

    One figure per cell is saved, and the nonlinearities are written to an
    .npz file. For backwards compatibility `nonlinearities`, `all_bins_c`
    and `all_bins_r` hold only the last computed combination (suppressing
    component, magnitude). The `*_allcomps` arrays hold every combination,
    indexed as [cell, Q component (order of `eiginds`), motion direction
    (X, Y, magnitude), ...].

    Parameters:
    ------
        nbins_nlt:
            Number of bins to be used for dividing the generator signals
            into ranges with equal number of samples.
        cmap:
            Colormap used for the 2D nonlinearities.
    """
    st = OMB(exp, stim)

    # Motion and contrast
    data_cm = np.load(os.path.join(st.exp_dir, 'data_analysis', st.stimname,
                                   'GQM_motioncontrast_val',
                                   f'{stim}_GQM_motioncontrast_val.npz'))
    kall = data_cm['kall']
    eigvecs = data_cm['eigvecs']
    eiginds = [-1, 0]  # activating, suppressing

    allspikes = st.allspikes()
    stim_mot = st.bgsteps.copy()

    row_labels = ['Activating', 'Suppressive']
    column_labels = ['X', 'Y', r'$\sqrt{X^2 + Y^2}$']
    n = len(column_labels)  # x, y, xy
    m = len(eiginds)  # activating, suppressing

    # Bin dimension should be one greater than nonlinearity for pcolormesh
    # compatibility. Otherwise the last row and column of nonlinearity is not
    # plotted.
    all_bins_c = np.zeros((st.nclusters, nbins_nlt+1))
    all_bins_r = np.zeros((st.nclusters, nbins_nlt+1))
    nonlinearities = np.zeros((st.nclusters, nbins_nlt, nbins_nlt))
    # Every (Q component, motion direction) combination per cell
    all_bins_c_allcomps = np.zeros((st.nclusters, m, n, nbins_nlt+1))
    all_bins_r_allcomps = np.zeros((st.nclusters, m, n, nbins_nlt+1))
    nonlinearities_allcomps = np.zeros((st.nclusters, m, n,
                                        nbins_nlt, nbins_nlt))

    label = '2D-nonlin_Qallcomps_motion_kcontrast'
    savedir = os.path.join(st.stim_dir, label)
    os.makedirs(savedir, exist_ok=True)

    barkwargs = dict(alpha=.3, facecolor='k', linewidth=.5, edgecolor='w')

    for i in range(st.nclusters):
        stim_con = st.contrast_signal_cell(i).squeeze()
        spikes = allspikes[i, :]
        # Project the contrast stimulus onto the linear filter. This does not
        # depend on the Q component, so it is computed once per cell.
        generator_c = np.convolve(stim_con, kall[i, 2, :],
                                  'full')[:-st.filter_length+1]

        fig = plt.figure(figsize=(n*5, m*5), constrained_layout=True)
        gs = fig.add_gridspec(m, n)
        # One group of axes per (component, direction) pair, row-major order
        axes = []
        for gs_row in range(m):
            for gs_col in range(n):
                subgs = gs[gs_row, gs_col].subgridspec(
                    2, 3, width_ratios=[4, 1, .2], height_ratios=[1, 4])
                mainax = fig.add_subplot(subgs[1, 0])
                axx = fig.add_subplot(subgs[0, 0], sharex=mainax)
                axy = fig.add_subplot(subgs[1, 1], sharey=mainax)
                cbax = fig.add_subplot(subgs[1, 2])
                axes.append([axx, mainax, axy, cbax])

        for k, eigind in enumerate(eiginds):
            # Project the motion stimulus onto the selected eigenvector of Q
            generator_x = np.convolve(eigvecs[i, 0, :, eigind],
                                      stim_mot[0, :],
                                      'full')[:-st.filter_length+1]
            generator_y = np.convolve(eigvecs[i, 1, :, eigind],
                                      stim_mot[1, :],
                                      'full')[:-st.filter_length+1]
            # Magnitude of the vector formed by the motion generators
            generator_xy = np.sqrt(generator_x**2 + generator_y**2)
            generators_motion = [generator_x, generator_y, generator_xy]

            for col, (direction, generator_m) in enumerate(
                    zip(column_labels, generators_motion)):
                nonlinearity, bins_c, bins_r = nlt.calc_nonlin_2d(
                    spikes, generator_c, generator_m, nr_bins=nbins_nlt)
                # Per-frame values divided by the frame duration; plotted
                # as spikes/s
                nonlinearity /= st.frame_duration

                all_bins_c[i, :] = bins_c
                all_bins_r[i, :] = bins_r
                nonlinearities[i, ...] = nonlinearity
                all_bins_c_allcomps[i, k, col, :] = bins_c
                all_bins_r_allcomps[i, k, col, :] = bins_r
                nonlinearities_allcomps[i, k, col, ...] = nonlinearity

                X, Y = np.meshgrid(bins_c, bins_r, indexing='ij')

                axx, axmain, axy, cbax = axes[k*n + col]
                # Normally subplots turns off shared axis tick labels but
                # Gridspec does not do this
                plt.setp(axx.get_xticklabels(), visible=False)
                plt.setp(axy.get_yticklabels(), visible=False)

                im = axmain.pcolormesh(X, Y, nonlinearity, cmap=cmap)
                plf.integerticks(axmain, 6, which='xy')

                cb = plt.colorbar(im, cax=cbax)
                cb.outline.set_linewidth(0)
                cb.ax.set_xlabel('spikes/s')
                cb.ax.xaxis.set_label_position('top')

                plf.integerticks(cb.ax, 4, which='y')
                plf.integerticks(axx, 1, which='y')
                plf.integerticks(axy, 1, which='x')

                # Marginals: nonlinearity averaged over the other axis
                axx.bar(nlt.bin_midpoints(bins_c), nonlinearity.mean(axis=1),
                        width=np.ediff1d(bins_c), **barkwargs)
                axy.barh(nlt.bin_midpoints(bins_r), nonlinearity.mean(axis=0),
                         height=np.ediff1d(bins_r), **barkwargs)
                plf.spineless(axx, 'b')
                plf.spineless(axy, 'l')

                if k == 0 and col == 0:
                    axmain.set_xlabel('Projection onto linear contrast filter')
                    axmain.set_ylabel('Projection onto Q component')
                if k == 0:
                    axx.set_title(direction)
                if col == 0:
                    axmain.text(-.3, .5, row_labels[k],
                                va='center',
                                rotation=90,
                                transform=axmain.transAxes)

        # `spikes` is this cell's row of allspikes, so no reload is needed.
        # No subplots_adjust here: it is incompatible with constrained_layout
        # (recent matplotlib ignores it with a warning).
        fig.suptitle(f'{st.exp_foldername}\n{st.stimname}\n{st.clids[i]} '
                     f'2D nonlinearity nsp: {spikes.sum():<5.0f}')
        fig.savefig(os.path.join(savedir, st.clids[i]), bbox_inches='tight')
        plt.show()

    datadict = {
        'nonlinearities': nonlinearities,
        'all_bins_c': all_bins_c,
        'all_bins_r': all_bins_r,
        'nbins_nlt': nbins_nlt,
        'nonlinearities_allcomps': nonlinearities_allcomps,
        'all_bins_c_allcomps': all_bins_c_allcomps,
        'all_bins_r_allcomps': all_bins_r_allcomps,
        'eiginds': eiginds,
    }
    npzfpath = os.path.join(savedir, f'{st.stimnr}_{label}.npz')
    np.savez(npzfpath, **datadict)
>>> contrast_avg = contrast.mean(axis=-1)
    >>> sta_corrected = subtract_avgcontrast(stas, contrast_avg)
    """
    # contrast_avg gets a new leading and a new trailing axis before the
    # subtraction, so sta is presumably 4-D (e.g. cell, x, y, time) -- confirm.
    return sta - contrast_avg[None, :, :, None]


exp, ombstimnr = '20180710_kilosorted', 8
checkerstimnr = 6

st = OMB(exp, ombstimnr, maxframes=None)

# Restrict the analysis to a hand-picked subset of cells
choosecells = [8, 33, 61, 73, 79]
nrcells = len(choosecells)
all_spikes = st.allspikes()[choosecells, :]
# NOTE(review): this `rw` is overwritten below with the normalized stimulus.
rw = asc.rolling_window(st.bgsteps, st.filter_length)

motionstas = np.array(st.read_datafile()['stas'])[choosecells, :]
# Divide each cell's STA by its total spike count (presumably the stored
# STAs are spike-triggered sums rather than averages -- confirm).
motionstas /= all_spikes.sum(axis=(-1))[:, np.newaxis, np.newaxis]
#%% Filter the stimuli
# Euclidian norm: scale every STA to unit norm along its last axis
motionstas_norm = motionstas / np.sqrt(
    (motionstas**2).sum(axis=-1))[:, :, None]
# Scale the motion stimulus to unit standard deviation
bgsteps = st.bgsteps / np.sqrt(st.bgsteps.var())
rw = asc.rolling_window(bgsteps, st.filter_length)