def playsta(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Create a looped animation for a single STA with 3 dimensions.

    Parameters
    ---------
    sta:
        Array containing the STA. It is sliced as ``sta[:, :, i]``, so the
        last axis indexes the animation frames.
    frame_duration:
        Time between frames in seconds (optional). When given and no
        explicit ``interval`` keyword is passed, the animation interval is
        set to ``frame_duration * 1000`` ms.
    cmap:
        Colormap to be used. Defaults to the specified colormap in the
        config file.
    centerzero:
        Center the colormap around zero if True. Otherwise the color
        limits are the minimum and maximum of `sta`.
    interval:
        Frame rate for the animation in ms. Overrides `frame_duration`.
    repeat_delay:
        Time to wait before the animation is repeated in ms.

    Any other keyword argument is forwarded to
    ``animation.ArtistAnimation``.

    Returns
    -------
    ani:
        The ``ArtistAnimation`` instance. Keep a reference to it, otherwise
        the animation may be garbage collected.

    Note
    ----
    The returned animation can be saved like so:

    >>> ani = playsta(sta)
    >>> ani.save('wheretosave/sta.gif', writer='imagemagick', fps=10)
    """
    check_interactive_backend()

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        # presumably symmetric limits around zero -- confirm against asc
        vmax = asc.absmax(sta)
        vmin = asc.absmin(sta)
    else:
        vmax, vmin = sta.max(), sta.min()

    # frame_duration used to be accepted but ignored; use it as the default
    # frame interval while still letting an explicit `interval` win.
    if frame_duration is not None:
        kwargs.setdefault('interval', frame_duration * 1000)

    fig = plt.figure()
    ax = plt.gca()
    # ArtistAnimation expects one list of artists per frame
    ims = [[ax.imshow(sta[:, :, i], animated=True, cmap=cmap,
                      vmin=vmin, vmax=vmax)]
           for i in range(sta.shape[-1])]
    ani = animation.ArtistAnimation(fig, ims, **kwargs)
    return ani
def stashow(sta, ax=None, cbar=True, **kwargs):
    """
    Plot an STA with a suitable colormap and, optionally, a colorbar.

    The STA can be a single frame from checkerflicker or a whole STA
    from stripeflicker.

    Accepted keyword arguments
    --------------------------
    Forwarded to imshow:
        extent: Change the labels of the axes. [xmin, xmax, ymin, ymax]
        aspect: Aspect ratio of the image. 'auto', 'equal'
        cmap:   Colormap to be used. Default is set in config
    Forwarded to colorbar:
        size:   Width of the colorbar as percentage of image dimension.
                Default is 2%
        ticks:  Where the ticks should be placed on the colorbar.
        format: Format for the tick labels. Default is '%.2f'

    Any other keyword raises ValueError.

    Returns
    -------
    The image object returned by ``ax.imshow``.

    Usage:
        ax = plt.subplot(111)
        stashow(sta, ax)
    """
    imshow_keys = ('extent', 'aspect', 'cmap')
    cbar_keys = ('size', 'ticks', 'format')

    vmax = asc.absmax(sta)
    vmin = asc.absmin(sta)

    # Defaults first; user-supplied keywords overwrite them below
    imshowkw = dict(cmap=iof.config('colormap'), vmin=vmin, vmax=vmax)
    cbarkw = dict(size='2%', ticks=[vmin, vmax], format='%.2f')

    for key, val in kwargs.items():
        if key in imshow_keys:
            imshowkw[key] = val
        elif key in cbar_keys:
            cbarkw[key] = val
        else:
            raise ValueError(f'Unknown kwarg: {key}')

    target_ax = plt.gca() if ax is None else ax

    im = target_ax.imshow(sta, **imshowkw)
    spineless(target_ax)
    if cbar:
        colorbar(im, **cbarkw)
    return im
ax.set_title(f'Component {mode_i}', fontweight='bold') ax.xaxis.set_ticklabels([]) elif row % nsubrows == 1: ax.plot(motionfilt_r[mode_i, :], color='k') if col == 0: ax.set_ylabel('Magnitude') ax.set_ylim(motionfilt_r.min(), motionfilt_r.max()) ax.xaxis.set_ticklabels([]) elif row % nsubrows == 2: ax.plot(motionfilt_theta[mode_i, :], color='r') ax.yaxis.set_ticks([-np.pi, 0, np.pi]) ax.yaxis.set_ticklabels(['-π', 0, 'π']) ax.xaxis.set_ticklabels([]) elif row % nsubrows == nsubrows - 1: im = ax.imshow(cells[mode_i, :], cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells), aspect='auto', interpolation='nearest') ax.set_xticks(np.array([0, .25, .5, .75, 1]) * cells.shape[-1]) ax.xaxis.set_ticklabels( np.round((ax.get_xticks() * st.frame_duration), 1)) ax.set_xlabel('Time [s]') if col == 0: ax.set_ylabel('Cells') if mode_i == n_components - 1: plf.colorbar(im) # fig.tight_layout() fig.subplots_adjust(wspace=0.3) fig.savefig( f'/Users/ycan/Downloads/2020-05-27_meeting/cca_DScells_shuffled.pdf') plt.show()
# Script cell: inspect the eigen-spectrum computed earlier and derive two
# quantities (bmel, cmel) from the STA, the STC and the stimulus covariance.
# Relies on names defined before this cell: spikes, stimulus, filter_length,
# eigvals, eigvecs, sta, stc, rw.

# NOTE(review): stc2 is not used anywhere in the visible part of the script.
stc2 = calc_stc(spikes, stimulus, filter_length)

# Top panel: eigenvalues. Bottom panel: STA together with the first and the
# last column of eigvecs.
fig, axes = plt.subplots(2, 1)
axes[0].plot(eigvals, 'o')
ax1 = axes[1]
ax1.plot(sta)
ax1.plot(eigvecs[:, 0])
ax1.plot(eigvecs[:, -1])
#%%
sp_mean = spikes.mean()
# rowvar=False: each column of rw is treated as a variable
stim_covar = np.cov(rw, rowvar=False)
# NOTE(review): if sta is 1-D, `sta @ sta.T` is a scalar inner product, not
# an outer product -- confirm the shape of sta matches the intended formula.
term1 = stc + (2 * (sta @ sta.T)) / sp_mean
# presumably the linear (bmel) and quadratic (cmel) terms of a quadratic
# model -- TODO confirm against the reference this was derived from
bmel = np.linalg.inv(term1) @ sta
cmel = (np.linalg.inv(stim_covar) - sp_mean * np.linalg.inv(term1)) / 2

fig2, axes2 = plt.subplots(2, 2)
# Flat iterator so the four panels can be addressed with a single index
axes2 = axes2.flat
axes2[0].plot(bmel)
axes2[0].plot(sta, ls='dashed', color='grey', label='STA')
axes2[1].imshow(cmel, cmap='seismic', vmin=asc.absmin(cmel), vmax=asc.absmax(cmel))
# eigh assumes cmel is symmetric; eigenvalues in v are in ascending order
v, w = np.linalg.eigh(cmel)
# Indices of the smallest and the largest eigenvalue
eiginds = [0, -1]
axes2[2].plot(v, 'ko')
# NOTE(review): plotting a single value places it at x=0, so the marker for
# eiginds[1] is not drawn at the position of the last eigenvalue.
axes2[2].plot(v[eiginds[0]], 'o')
axes2[2].plot(v[eiginds[1]], 'o')
# Eigenvectors belonging to the two selected eigenvalues
axes2[3].plot(w[:, eiginds])
def cca_omb_components(exp: str, stim_nr: int,
                       n_components: int = 6,
                       regularization=None,
                       filter_length=None,
                       cca_solver: str = 'macke',
                       maxframes=None,
                       shufflespikes: bool = False,
                       exclude_allzero_spike_rows: bool = True,
                       savedir: str = None,
                       savefig: bool = True,
                       sort_by_nspikes: bool = True,
                       select_cells: list = None,
                       plot_first_ncells: int = None,
                       whiten: bool = False):
    """
    Analyze OMB responses using canonical correlation analysis and plot the
    results.

    Four figures are produced (cell weights in default order, cells and
    stimulus components, canonical correlations, nonlinearities) and the
    resulting arrays are stored in an ``.npz`` file in `savedir`.

    Parameters
    ---
    exp, stim_nr:
        Passed to ``OMB`` to load the stimulus object.
    n_components:
        Number of components that will be requested from the CCA analysis.
        More numbers mean the algorithm will stop at a later point. That
        means components of analyses with fewer n_components are going to be
        identical to the first n components of the higher-number component
        analyses.
    regularization:
        The regularization parameter to be passed onto rcca.CCA. Not
        relevant for macke. None is treated as 0.
    filter_length:
        The length of the time window to be considered in the past for the
        stimulus and the responses. Can be different for stimulus and
        response, if a tuple is given. Defaults to ``st.filter_length``.
    cca_solver:
        Which CCA solver to use. Options are `rcca` and `macke`(default)
    maxframes: int
        Number of frames to load in the experiment object. Used to avoid
        memory and performance issues.
    shufflespikes: bool
        Whether to randomize the spikes, to validate the results
    exclude_allzero_spike_rows:
        Exclude all cells which have zero spikes for the duration of the
        stimulus.
    savedir: str
        Custom directory to save the figures and data files. If None, will
        be saved in the experiment directory under appropriate path.
    savefig: bool
        Whether to save the figures. The data file is written regardless.
    sort_by_nspikes: bool
        Whether to sort the cell weights array by the number of spikes
        during the stimulus.
    select_cells: list
        A list of indexes for the subset of cells to perform the analysis
        for.
    plot_first_ncells: int
        Number of cells to plot in the cell plots.
    whiten: bool
        Forwarded to ``cca_rcca``; not used by the macke solver.

    Raises
    ------
    ValueError
        If `cca_solver` is not one of the supported options.
    """
    if cca_solver not in ('rcca', 'macke'):
        # Fail early instead of hitting a NameError after the data is loaded
        raise ValueError(
            f"Unknown cca_solver: {cca_solver!r}, expected 'rcca' or 'macke'")

    if regularization is None:
        regularization = 0

    st = OMB(exp, stim_nr, maxframes=maxframes)

    if filter_length is None:
        filter_length = st.filter_length
    if isinstance(filter_length, int):
        filter_length = (filter_length, filter_length)

    if isinstance(savedir, str):
        savedir = Path(savedir)
    if savedir is None:
        savedir = st.stim_dir / 'CCA'
    savedir.mkdir(exist_ok=True, parents=True)

    spikes = st.allspikes()

    if select_cells is not None:
        select_cells = np.asarray(select_cells)
        spikes = spikes[select_cells]
        st.nclusters = len(select_cells)
        # Convert to a list of plain ints for better string representation;
        # an array is printed as "array([....])" with newline characters,
        # which is problematic in filenames.
        select_cells = select_cells.tolist()

    # Exclude rows that have no spikes throughout. The mask is computed on
    # the already sub-selected, not yet mean-centered array so that it lines
    # up with the rows it is applied to.
    if exclude_allzero_spike_rows:
        nonzerospikerows = ~np.isclose(spikes.sum(axis=1), 0)
        spikes = spikes[nonzerospikerows, :]

    # Spike counts have to be taken before centering: afterwards every row
    # sums to (numerically) zero and sorting by it would be meaningless.
    nspikes_percell = spikes.sum(axis=1)

    # Set the mean to zero for spikes. Not done in place, so the array
    # returned by st.allspikes() is left untouched and integer spike counts
    # are handled as well.
    spikes = spikes - spikes.mean(axis=1)[:, None]

    if shufflespikes:
        spikes = spikeshuffler.shufflebyrow(spikes)

    figsavename = f'{n_components=}_{shufflespikes=}_{select_cells=}_{regularization=}_{filter_length=}_{whiten=}'
    # If the file name gets too long due to the list of selected cells,
    # summarize it.
    if len(figsavename) > 200 and select_cells is not None:
        figsavename = f'{n_components=}_{shufflespikes=}_select_cells={len(select_cells)}cells-index{select_cells[0]}to{select_cells[-1]}_{regularization=}_{filter_length=}_{whiten=}'

    stimulus = mft.packdims(st.bgsteps, filter_length[0])
    spikes = mft.packdims(spikes, filter_length[1])

    if cca_solver == 'rcca':
        resp_comps, stim_comps, cancorrs = cca_rcca(spikes, stimulus,
                                                    filter_length,
                                                    n_components,
                                                    regularization, whiten)
    else:  # 'macke', validated at the top
        resp_comps, stim_comps, cancorrs = cca_macke(spikes, stimulus,
                                                     filter_length,
                                                     n_components)

    nsp_argsorted = np.argsort(nspikes_percell)
    resp_comps_sorted_nsp = resp_comps[:, nsp_argsorted, :]

    resp_comps_toplot = resp_comps_sorted_nsp if sort_by_nspikes else resp_comps
    if plot_first_ncells is not None:
        resp_comps_toplot = resp_comps_toplot[:, :plot_first_ncells, ...]

    motionfilt_r, motionfilt_theta = mft.cart2pol(stim_comps[:, 0, :],
                                                  stim_comps[:, 1, :])

    # Color limits shared by all cell-weight images
    resp_vmin = asc.absmin(resp_comps)
    resp_vmax = asc.absmax(resp_comps)

    #%% Cell weights in the default order
    nrows, ncols = plf.numsubplots(n_components)
    # squeeze=False keeps the axes array 2-D even for a single row/column
    fig_cells, axes_cells = plt.subplots(nrows, ncols, figsize=(10, 10),
                                         squeeze=False)
    for i in range(n_components):
        ax = axes_cells.flat[i]
        im = ax.imshow(resp_comps[i, :],
                       cmap='RdBu_r',
                       vmin=resp_vmin,
                       vmax=resp_vmax,
                       aspect='auto',
                       interpolation='nearest')
        ax.set_title(f'{i}')
    fig_cells.suptitle(f'Cells default order {shufflespikes=}')
    if savefig:
        fig_cells.savefig(savedir / f'{figsavename}_cells_default_order.pdf')
    plt.close(fig_cells)

    #%% Stimulus components and cell weights, one block of rows per component
    height_list = [1, 1, 1, 3]  # ratios of the plots in each component
    nsubrows = len(height_list)
    height_ratios = nrows * height_list

    # Time vectors in ms before the spike. Built from an integer range so
    # that floating point rounding cannot change the number of samples.
    t_stim = -(np.arange(filter_length[0]) * st.frame_duration)[::-1] * 1000
    t_response = -(np.arange(filter_length[1]) * st.frame_duration)[::-1] * 1000
    xtick_loc_params = dict(nbins=4, steps=[2, 5, 10], integer=True)

    fig, axes = plt.subplots(nrows=nrows * nsubrows,
                             ncols=ncols,
                             gridspec_kw={'height_ratios': height_ratios},
                             figsize=(11, 10),
                             squeeze=False)

    for row, ax_row in enumerate(axes):
        for col, ax in enumerate(ax_row):
            mode_i = int(row / nsubrows) * ncols + col
            if mode_i >= n_components:
                # The grid can have more cells than there are components
                ax.set_axis_off()
                continue
            subrow = row % nsubrows

            ax.set_yticks([])

            # Plot motion filters
            if subrow == 0:
                ax.plot(t_stim, stim_comps[mode_i, 0, :],
                        marker='o', markersize=1)
                ax.plot(t_stim, stim_comps[mode_i, 1, :],
                        marker='o', markersize=1)
                if col == 0:
                    ax.set_ylabel('Motion', rotation=0,
                                  ha='right', va='center')
                ax.set_ylim(stim_comps.min(), stim_comps.max())

                # Draw a horizontal line for zero and prevent rescaling of
                # the x-axis
                xlims = ax.get_xlim()
                ax.hlines(0, *ax.get_xlim(), colors='k',
                          linestyles='dashed', alpha=0.3)
                ax.set_xlim(*xlims)

                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

                # Tick labels are only kept on the first component, and only
                # when stimulus and response windows differ
                if mode_i != 0 or filter_length[0] == filter_length[1]:
                    ax.xaxis.set_ticklabels([])
                else:
                    ax.tick_params(axis='x', labelsize=8)

            # Plot magnitude of motion
            elif subrow == 1:
                ax.plot(t_stim, motionfilt_r[mode_i, :], color='k',
                        marker='o', markersize=1)
                if col == 0:
                    ax.set_ylabel('Magnitude', rotation=0,
                                  ha='right', va='center')
                ax.set_ylim(motionfilt_r.min(), motionfilt_r.max())
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

            # Plot direction of motion
            elif subrow == 2:
                ax.plot(t_stim, motionfilt_theta[mode_i, :], color='r',
                        marker='o', markersize=1)
                if mode_i == 0:
                    ax.yaxis.set_ticks([-np.pi, 0, np.pi])
                    ax.yaxis.set_ticklabels(['-π', 0, 'π'])
                ax.xaxis.set_ticklabels([])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))

            # Plot cell weights
            elif subrow == nsubrows - 1:
                im = ax.imshow(resp_comps_toplot[mode_i, :],
                               cmap='RdBu_r',
                               vmin=resp_vmin,
                               vmax=resp_vmax,
                               aspect='auto',
                               interpolation='nearest',
                               extent=[
                                   t_response[0], t_response[-1], 0,
                                   resp_comps_toplot.shape[1]
                               ])
                ax.xaxis.set_major_locator(MaxNLocator(**xtick_loc_params))
                if row == axes.shape[0] - 1:
                    ax.set_xlabel('Time before spike [ms]')
                else:
                    ax.xaxis.set_ticklabels([])

                plf.integerticks(ax, 5, which='y')
                if col == 0:
                    ax.set_ylabel(
                        f'Cells\n{"(sorted nsp)"*sort_by_nspikes}\n{("(first " + str(plot_first_ncells)+ " cells)")*(type(plot_first_ncells) is int) }',
                        rotation=0, ha='right', va='center')
                else:
                    ax.yaxis.set_ticklabels([])
                if mode_i == n_components - 1:
                    plf.colorbar(im)

            # Add ticks on the right side of the plots
            if col == ncols - 1 and subrow != nsubrows - 1:
                plf.integerticks(ax, 3, which='y')
                ax.yaxis.tick_right()

    fig.suptitle(
        f'CCA components of {st.exp_foldername}\n{shufflespikes=} {n_components=}\n{sort_by_nspikes=}\n'
        + f'{select_cells=} {regularization=} {filter_length=}')
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    if savefig:
        fig.savefig(savedir / f'{figsavename}_cellsandcomponents.pdf')
    plt.close(fig)

    #%% Canonical correlations
    fig_corrs = plt.figure()
    plt.plot(cancorrs, 'ko')
    plt.xlabel('Component index')
    plt.ylabel('Correlation')
    plt.title(f'Cannonical correlations {shufflespikes=}')
    if savefig:
        fig_corrs.savefig(savedir / f'{figsavename}_correlation_coeffs.pdf')
    plt.close(fig_corrs)

    #%% Nonlinearities between the two generator signals
    fig_nlt, axes_nlt = plt.subplots(nrows, ncols, figsize=(10, 10),
                                     squeeze=False)

    # Flatten the components to perform the convolution as a matrix
    # multiplication
    stim_comps_flatter = stim_comps[:n_components].reshape(
        (n_components, 2 * filter_length[0])).T
    resp_comps_flatter = resp_comps[:n_components].reshape(
        (n_components, resp_comps.shape[1] * filter_length[1])).T

    generator_stim = stimulus @ stim_comps_flatter
    generator_resp = spikes @ resp_comps_flatter

    # Only the first n_components axes of the grid carry data
    used_axes_nlt = axes_nlt.ravel()[:n_components]

    for i, ax in enumerate(used_axes_nlt):
        nonlinearity, bins = nlt.calc_nonlin(generator_resp[:, i],
                                             generator_stim[:, i])
        ax.plot(bins, nonlinearity, 'k')

        if i == 0:
            # Shapes are only known after the first nonlinearity is computed
            all_nonlinearities = np.empty((n_components, *nonlinearity.shape))
            all_bins = np.empty((n_components, *bins.shape))
        all_nonlinearities[i, ...] = nonlinearity
        all_bins[i, ...] = bins

    # Use common axis limits for all the nonlinearity panels
    nlt_xlims = []
    nlt_ylims = []
    for ax in used_axes_nlt:
        nlt_xlims.extend(ax.get_xlim())
        nlt_ylims.extend(ax.get_ylim())
    nlt_maxx, nlt_minx = max(nlt_xlims), min(nlt_xlims)
    nlt_maxy, nlt_miny = max(nlt_ylims), min(nlt_ylims)

    for i, axes_row in enumerate(axes_nlt):
        for j, ax in enumerate(axes_row):
            if i * ncols + j >= n_components:
                ax.set_axis_off()
                continue
            if i == nrows - 1:
                ax.set_xlabel('Generator (motion)')
            if j == 0:
                ax.set_ylabel('Generator (cells)')
            else:
                ax.yaxis.set_ticklabels([])
            ax.set_xlim([nlt_minx, nlt_maxx])
            ax.set_ylim([nlt_miny, nlt_maxy])

    fig_nlt.suptitle(f'Nonlinearities\n{figsavename}')
    if savefig:
        fig_nlt.savefig(savedir / f'{figsavename}_nonlinearity.png')
    plt.close(fig_nlt)

    #%% Store the results next to the figures
    datadict = {
        'n_components': n_components,
        'resp_comps': resp_comps,
        'stim_comps': stim_comps,
        'motionfilt_r': motionfilt_r,
        'motionfilt_theta': motionfilt_theta,
        'resp_comps_sorted_nsp': resp_comps_sorted_nsp,
        'select_cells': select_cells,
        'regularization': regularization,
        'filter_length': filter_length,
        'all_nonlinearities': all_nonlinearities,
        'all_bins': all_bins,
        'cca_solver': cca_solver,
    }
    np.savez(savedir / figsavename, **datadict)
# Script cell: visualise the weights of an already fitted `cca` object.
# Relies on names defined before this cell: cca, filter_length, n_components,
# stimulus, plf, asc, np.
import matplotlib.pyplot as plt

# Transposed so that rows are components and columns are cells (see the axis
# labels of the image below)
cells = cca.x_weights_.T
# The first filter_length rows of y_weights_ and the remaining rows are
# treated as two separate stimulus filters; filter_length is used as a plain
# integer slice bound here.
# presumably the x and y motion components -- confirm against how the
# stimulus matrix was assembled
motionfilt_x = cca.y_weights_[:filter_length].T
motionfilt_y = cca.y_weights_[filter_length:].T

# One panel per component, both stimulus filters overlaid on a common y-range
fig, axes = plt.subplots(*plf.numsubplots(n_components))
for i, ax in enumerate(axes.flat):
    ax.plot(motionfilt_x[i, :])
    ax.plot(motionfilt_y[i, :])
    ax.set_ylim(cca.y_weights_.min(), cca.y_weights_.max())
fig.suptitle('Stimulus components for each CCA filter')
plt.show()

# Cell weights of all components as a single image
im = plt.imshow(cells, cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells))
plt.ylabel('Components')
plt.xlabel('Cells')
plf.colorbar(im)
plt.show()

# Scatter of the first two columns of the coefficient matrix
plt.plot(cca.coef_[:, 0], cca.coef_[:, 1], 'ko')
plt.show()

# Stack the stimulus weights of all components in one plot by shifting each
# component vertically by a multiple of 0.6; the thin horizontal lines mark
# the zero level of each shifted trace.
offsets = np.arange(n_components)*.6
plt.plot(cca.y_weights_ + np.tile(offsets, (stimulus.shape[1], 1)), lw=0.5)
plt.hlines(offsets, 0, stimulus.shape[1], lw=.2)
ax.set_ylim(cca.y_weights_.min(), cca.y_weights_.max()) ax.set_title(f'Component {mode_i}', fontweight='bold') ax.xaxis.set_ticklabels([]) elif row % nsubrows == 1: ax.plot(motionfilt_r[mode_i, :], color='k') if col==0: ax.set_ylabel('Magnitude') ax.set_ylim(motionfilt_r.min(), motionfilt_r.max()) ax.xaxis.set_ticklabels([]) elif row % nsubrows == 2: ax.plot(motionfilt_theta[mode_i, :], color='r') ax.yaxis.set_ticks([-np.pi, 0, np.pi]) ax.yaxis.set_ticklabels(['-π', 0, 'π']) ax.xaxis.set_ticklabels([]) elif row % nsubrows == nsubrows-1: im = ax.imshow(cells[mode_i, :], cmap='RdBu_r', vmin=asc.absmin(cells), vmax=asc.absmax(cells), aspect='auto', interpolation='nearest') ax.set_xticks(np.array([0, .25, .5, .75, 1]) * cells.shape[-1]) ax.xaxis.set_ticklabels(np.round((ax.get_xticks()*st.frame_duration), 1)) ax.set_xlabel('Time [s]') if col==0: ax.set_ylabel('Cells') # fig.tight_layout() fig.subplots_adjust(wspace=0.3) selcels = '_'.join(map(str, selected_cells)) fig.savefig(f'/Users/ycan/Downloads/2020-05-27_meeting/{selcels}_shuffled{shuffle_spikes}.pdf') plt.show() #%%
def stabrowser(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Returns an interactive plot to browse a spatiotemporal STA. Requires an
    interactive matplotlib backend.

    Parameters
    --------
    sta:
        Numpy array containing the STA. Last dimension should index time.
    frame_duration:
        Time between each frame. (optional) When given, the figure title
        shows ``frame * frame_duration * 1000`` in ms after the slider is
        moved.
    cmap:
        Colormap to use. Defaults to the colormap set in the config.
    centerzero:
        Whether to center the colormap around zero for diverging colormaps.
    **kwargs:
        Forwarded to ``imshow``.

    Returns
    -------
    fig, slider_t:
        The figure and the slider widget.

    Example
    ------
    >>> print(sta.shape) # (xpixels, ypixels, time)
    (75, 100, 40)
    >>> fig, slider = stabrowser(sta, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the
    reference to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    check_interactive_backend()

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(sta)
        vmin = asc.absmin(sta)
    else:
        vmax, vmin = sta.max(), sta.min()
    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin, **kwargs)

    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

    last_frame = sta.shape[-1] - 1
    # Start at frame 5, but stay within range for STAs with fewer frames
    initial_frame = min(5, last_frame)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is set to a variable and is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, last_frame,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        # Swap the data of the existing image rather than drawing a new one
        frame = int(frame)
        im = ax.get_images()[0]
        im.set_data(sta[..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    ax.imshow(sta[..., initial_frame], **imshowkwargs)
    ax.set_axis_off()
    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
def multistabrowser(stas, frame_duration=None, normalize=True, cmap=None,
                    centerzero=True, **kwargs):
    """
    Returns an interactive plot to browse multiple spatiotemporal
    STAs at the same time. Requires an interactive matplotlib backend.

    Parameters
    --------
    stas:
        Numpy array containing STAs. First dimension should index
        individual cells, last dimension should index time.
        Alternatively, this could be a list (or tuple) of numpy arrays.
    frame_duration:
        Time between each frame. (optional) When given, the figure title
        shows ``frame * frame_duration * 1000`` in ms after the slider is
        moved.
    normalize:
        Whether to normalize each STA
    cmap:
        Colormap to use. Defaults to the colormap set in the config.
    centerzero:
        Whether to center the colormap around zero for diverging colormaps.
    **kwargs:
        Forwarded to ``imshow``.

    Returns
    -------
    fig, slider_t:
        The figure and the slider widget. The image axes come first in
        ``fig.axes`` in row-major order.

    Example
    ------
    >>> print(stas.shape) # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = multistabrowser(stas, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the
    reference to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    check_interactive_backend()

    if isinstance(stas, (list, tuple)):
        stas = np.array(stas)

    if normalize:
        stas = asc.normalize(stas)

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(stas)
        vmin = asc.absmin(stas)
    else:
        vmax, vmin = stas.max(), stas.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin, **kwargs)

    ncells = stas.shape[0]
    rows, cols = numsubplots(ncells)
    # squeeze=False keeps `axes` 2-D even when the grid has a single row or
    # column, so the indexing below works for any number of cells
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)
    # Row-major flattening; the first ncells axes hold the images
    cell_axes = axes.ravel()[:ncells]

    last_frame = stas.shape[-1] - 1
    # Start at frame 5, but stay within range for STAs with fewer frames
    initial_frame = min(5, last_frame)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is set to a variable and is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, last_frame,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        # Swap the data of the existing images rather than drawing new ones
        frame = int(frame)
        for sta, ax in zip(stas, cell_axes):
            ax.get_images()[0].set_data(sta[..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    for sta, ax in zip(stas, cell_axes):
        ax.imshow(sta[..., initial_frame], **imshowkwargs)
    # Hide the frames of every panel, including unused ones in the grid
    for ax in axes.flat:
        ax.set_axis_off()

    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
# Script cell: joint scatter plot of the 'x' and 'y' columns of `sizes`.
# Relies on names defined before this cell: sns, sizes, stas, max_inds,
# all_pars, drawellipse, plt, asc.
g = sns.jointplot(
    'x',
    'y',
    sizes,
    'scatter',
    # shade_lowest = False,
    # , xlim=[0, 6], ylim=[0, 6]
)
#%%
import plotfuncs as plf
# Called without arguments on purpose, per the original comment, to stop the
# script at this point when the whole file is run.
# NOTE(review): presumably this raises because of the missing argument --
# confirm, and consider an explicit exception instead.
plf.absmax()  # stop the script
# Show, for each STA, the frame selected by the last entry of max_inds[i],
# with the ellipse described by all_pars[i] drawn on top.
for i in range(len(stas)):
    sta = stas[i]
    plt.imshow(sta[..., max_inds[i][-1]], cmap='RdBu_r', vmax=asc.absmax(sta), vmin=asc.absmin(sta))
    drawellipse(all_pars[i])
    plt.title(f'{i}')
    plt.show()
#%%
# Browse all STAs at once and draw each ellipse on its matching panel.
# `sl` keeps a reference to the slider so that it stays interactive.
fig, sl = plf.multistabrowser(stas)
for i in range(len(stas)):
    # assumes the first len(stas) entries of fig.axes are the image axes in
    # the same order as stas -- TODO confirm
    ax = fig.axes[i]
    drawellipse(all_pars[i], ax)