def playsta(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Create a looped animation for a single STA with 3 dimensions.

    Parameters
    ----------
    sta:
        STA array; animation frames are taken along the last axis
        (``sta[:, :, i]``).
    frame_duration:
        Accepted for signature compatibility with the STA browsers; it is
        currently not used by this function.
    cmap:
        Colormap to be used. Defaults to the colormap specified in the
        config file.
    centerzero:
        If True, color limits come from ``asc.absmax``/``asc.absmin``
        (centered around zero); otherwise the data min/max are used.
    **kwargs:
        Forwarded to ``matplotlib.animation.ArtistAnimation``, e.g.
        interval: Delay between frames in ms.
        repeat_delay: Time to wait before the animation is repeated in ms.

    Returns
    -------
    ani: matplotlib.animation.ArtistAnimation

    Note
    ----
    The returned animation can be saved like so:
    >>> ani = playsta(sta)
    >>> ani.save('wheretosave/sta.gif', writer='imagemagick', fps=10)
    """
    check_interactive_backend()
    colormap = iof.config('colormap') if cmap is None else cmap

    if centerzero:
        upper, lower = asc.absmax(sta), asc.absmin(sta)
    else:
        upper, lower = sta.max(), sta.min()

    fig = plt.figure()
    ax = plt.gca()
    # ArtistAnimation expects one list of artists per frame.
    frames = [
        [ax.imshow(sta[:, :, idx], animated=True, cmap=colormap,
                   vmin=lower, vmax=upper)]
        for idx in range(sta.shape[-1])
    ]
    return animation.ArtistAnimation(fig, frames, **kwargs)
def stashow(sta, ax=None, cbar=True, **kwargs):
    """
    Plot STA in a nice way with proper colormap and colorbar.

    STA can be single frame from checkerflicker or whole STA
    from stripeflicker. Color limits come from ``asc.absmax`` and
    ``asc.absmin`` of the STA.

    Keyword arguments
    -----------------
    Forwarded to imshow:
        extent: Change the labels of the axes. [xmin, xmax, ymin, ymax]
        aspect: Aspect ratio of the image. 'auto', 'equal'
        cmap:   Colormap to be used. Default is set in config.
    Forwarded to colorbar:
        size:   Width of the colorbar as percentage of image dimension.
                Default is '2%'.
        ticks:  Where the ticks should be placed on the colorbar.
                Default is [vmin, vmax].
        format: Format for the tick labels. Default is '%.2f'.
    Any other keyword raises ValueError.

    Parameters
    ----------
    ax: Axes to draw on. The current axes are used if None.
    cbar: Whether to draw a colorbar.

    Returns
    -------
    The image returned by ``ax.imshow``.

    Usage:
        ax = plt.subplot(111)
        stashow(sta, ax)
    """
    upper = asc.absmax(sta)
    lower = asc.absmin(sta)

    # Defaults; each can be overridden through **kwargs.
    image_opts = dict(cmap=iof.config('colormap'), vmin=lower, vmax=upper)
    cbar_opts = dict(size='2%', ticks=[lower, upper], format='%.2f')

    imshow_keys = ('extent', 'aspect', 'cmap')
    cbar_keys = ('size', 'ticks', 'format')
    for key, value in kwargs.items():
        if key in imshow_keys:
            target = image_opts
        elif key in cbar_keys:
            target = cbar_opts
        else:
            raise ValueError(f'Unknown kwarg: {key}')
        target[key] = value

    target_ax = plt.gca() if ax is None else ax
    im = target_ax.imshow(sta, **image_opts)
    spineless(target_ax)
    if cbar:
        colorbar(im, **cbar_opts)
    return im
def read_spikesheet(exp_name, cutoff=4, defaultpath=True, onlymetadata=False):
    """
    Read metadata and cluster information from spike sorting file
    (manually created during spike sorting), return good clusters.

    Parameters
    ----------
    exp_name:
        Experiment name for the directory that contains the
        .xlsx or .ods file. Possible file names may be set in
        the configuration file. Fallback/default name is
        'spike_sorting.[ods|xlsx]'.
    cutoff:
        Worst rating that is tolerated for analysis. Default
        is 4. The source of this value is manual rating of each
        cluster.
    defaultpath:
        Whether to iterate over all possible file names in exp_dir.
        If False, the full path to the .ods or .xlsx file should be
        supplied in exp_name; the sheet layout is chosen from the
        file extension.
    onlymetadata:
        Read the file and return only the metadata information.

    Returns
    -------
    clusters:
        Channel number, cluster number and rating of those clusters
        that match the cutoff criteria in a numpy array, sorted by
        channel number and then cluster number.
    metadata:
        Information about the experiment in a dictionary.

    If ``onlymetadata`` is True, only ``metadata`` is returned. If the
    experiment is kilosorted (and ``onlymetadata`` is False), the result
    of ``readks.read_spikesheet_ks`` is returned instead.

    Raises
    ------
    FileNotFoundError:
        If no spike sorting file can be located.
    ValueError:
        If the spike sorting file contains incomplete information, or
        (with defaultpath=False) the file is neither .ods nor .xlsx.

    Notes
    -----
    The script assumes adherence to defined cell locations for
    metadata and cluster information. If changed undefined behavior
    may occur.
    """
    # Cell ranges per file format as
    # [first_row, first_col, stop_row, stop_col] (0-based, stop exclusive).
    layouts = {
        '.ods': dict(meta_keys=[0, 0, 1, 25],
                     meta_vals=[1, 0, 2, 25],
                     cluster_chnl=[4, 0, 2000, 1],
                     cluster_cltr=[4, 4, 2000, 5],
                     cluster_rtng=[4, 5, 2000, 6]),
        '.xlsx': dict(meta_keys=[4, 1, 25, 2],
                      meta_vals=[4, 5, 25, 6],
                      cluster_chnl=[51, 1, 2000, 2],
                      cluster_cltr=[51, 5, 2000, 6],
                      cluster_rtng=[51, 6, 2000, 7]),
    }

    if defaultpath:
        exp_dir = iof.exp_dir_fixer(exp_name)
        filenames = iof.config('spike_sorting_filenames')
        if iskilosorted(exp_name) and not onlymetadata:
            import readks
            return readks.read_spikesheet_ks(exp_name)
        # Try every configured name, preferring .ods over .xlsx.
        filepath = layout = None
        for name in filenames:
            base = os.path.join(exp_dir, name)
            for ext in ('.ods', '.xlsx'):
                if os.path.isfile(base + ext):
                    filepath, layout = base + ext, layouts[ext]
                    break
            if filepath is not None:
                break
        if filepath is None:
            raise FileNotFoundError('Spike sorting file (ods/xlsx) not found.')
    else:
        filepath = exp_name
        if not os.path.isfile(filepath):
            raise FileNotFoundError(
                f'Spike sorting file not found: {filepath}')
        ext = os.path.splitext(filepath)[1].lower()
        if ext not in layouts:
            raise ValueError('Spike sorting file must be .ods or .xlsx, '
                             f'got: {filepath}')
        layout = layouts[ext]

    sheet = np.array(pyexcel.get_array(file_name=filepath, sheets=[0]))

    def cells(bounds):
        """Return the sheet region described by a layout entry."""
        row0, col0, row1, col1 = bounds
        return sheet[row0:row1, col0:col1]

    meta_keys = cells(layout['meta_keys'])
    meta_vals = cells(layout['meta_vals'])
    metadata = dict(zip(meta_keys.ravel(), meta_vals.ravel()))
    if onlymetadata:
        return metadata

    # Concatenate cluster information
    cluster_chnl = layout['cluster_chnl']
    clusters = cells(cluster_chnl)
    cl = np.argmin(clusters.shape)
    clusters = np.append(clusters, cells(layout['cluster_cltr']), axis=cl)
    clusters = np.append(clusters, cells(layout['cluster_rtng']), axis=cl)
    if cl != 1:
        clusters = clusters.T
    # Drop rows that are completely empty.
    clusters = clusters[np.any(clusters != [['', '', '']], axis=1)]

    # The channels with multiple clusters have an empty line after the first
    # line. Fill the empty lines using the first line of each channel.
    # Starting from '' means an empty first channel cell is left empty and
    # reported by the check below instead of raising UnboundLocalError.
    nr = ''
    for i, c in enumerate(clusters[:, 0]):
        if c != '':
            nr = c
        else:
            clusters[i, 0] = nr

    if '' in clusters:
        rowcol = (np.where(clusters == '')[1 - cl][0]
                  + 1 + cluster_chnl[1 - cl])
        raise ValueError('Spike sorting file is missing information in '
                         '{} {}.'.format(['column', 'row'][cl], rowcol))
    clusters = clusters.astype(int)

    # Sort the clusters in ascending order based on channel number
    # Normal sort function messes up the other columns for some reason
    # so we explicitly use lexsort for the columns containing channel nrs
    # Order of the columns given in lexsort are in reverse
    sorted_idx = np.lexsort((clusters[:, 1], clusters[:, 0]))
    clusters = clusters[sorted_idx, :]

    # Filter according to quality cutoff
    clusters = clusters[clusters[:, 2] <= cutoff]

    return clusters, metadata
def checkerflickerplusanalyzer(exp_name, stimulusnr, clusterstoanalyze=None,
                               frametimingsfraction=None, cutoff=4):
    """
    Analyzes checkerflicker-like data, typically interspersed stimuli in
    between chunks of checkerflicker, e.g. checkerflickerplusmovie,
    frozennoise.

    The running (non-repeated) checkerflicker frames are regenerated chunk
    by chunk with ``randpy.ranb`` and spike-triggered averages (STAs) are
    accumulated for every cluster. After each chunk ``asc.staquality`` is
    evaluated for every STA, so the saved plot shows how STA quality
    develops with the amount of recording analyzed.

    Parameters
    ----------
    exp_name:
        Experiment name.
    stimulusnr:
        Number of the stimulus to be analyzed.
    clusterstoanalyze:
        Number of clusters that should be analyzed. Default is None.
        If given, only the first N cells will be analyzed. In case of long
        recordings it might make sense to first look at a subset of cells
        before starting to analyze the whole dataset.
    frametimingsfraction:
        Fraction of the recording to analyze. Should be a number between
        0 and 1, e.g. 0.3 will analyze the first 30% of the whole
        recording.
    cutoff:
        Worst rating that is wanted for the analysis. Default is 4.
        The source of this value is manual rating of each cluster.

    Raises
    ------
    ValueError
        If frametimingsfraction is not between 0 and 1, the stimulus is not
        black and white, the screen is not an integer number of stixels,
        or the stimulus type is neither 'FrozenNoise' nor
        'checkerflickerplusmovie'.

    Output
    ------
    Saves the results with ``np.savez`` and an .svg quality plot to
    ``<exp_dir>/data_analysis/<stimname>/<stimulusnr>_data[...]``.
    """
    exp_dir = iof.exp_dir_fixer(exp_name)
    stimname = iof.getstimname(exp_dir, stimulusnr)
    exp_name = os.path.split(exp_dir)[-1]

    clusters, metadata = asc.read_spikesheet(exp_dir, cutoff=cutoff)

    # Check that the inputs are as expected.
    if clusterstoanalyze:
        if clusterstoanalyze > len(clusters[:, 0]):
            warnings.warn('clusterstoanalyze is larger '
                          'than number of clusters in dataset. '
                          'All cells will be included.')
            clusterstoanalyze = None
    if frametimingsfraction:
        if not 0 < frametimingsfraction < 1:
            raise ValueError('Invalid input for frametimingsfraction: {}. '
                             'It should be a number between 0 and 1'
                             ''.format(frametimingsfraction))

    scr_width = metadata['screen_width']
    scr_height = metadata['screen_height']
    refresh_rate = metadata['refresh_rate']

    parameters = asc.read_parameters(exp_dir, stimulusnr)

    stx_h = parameters['stixelheight']
    stx_w = parameters['stixelwidth']

    # Check whether any parameters are given for margins, calculate
    # screen dimensions. Missing margins default to 0.
    marginkeys = ['tmargin', 'bmargin', 'rmargin', 'lmargin']
    margins = [parameters.get(key, 0) for key in marginkeys]

    # Subtract bottom and top from vertical dimension; left and right
    # from horizontal dimension
    scr_width = scr_width - sum(margins[2:])
    scr_height = scr_height - sum(margins[:2])

    nblinks = parameters['Nblinks']
    bw = parameters.get('blackwhite', False)

    # Gaussian stimuli are not supported yet, we need to ensure we
    # have a black and white stimulus
    if bw is not True:
        raise ValueError('Gaussian stimuli are not supported yet!')

    seed = parameters.get('seed', -1000)

    sx, sy = scr_height / stx_h, scr_width / stx_w

    # Make sure that the number of stimulus pixels are integers
    # Rounding down is also possible but might require
    # other considerations.
    if sx % 1 == 0 and sy % 1 == 0:
        sx, sy = int(sx), int(sy)
    else:
        raise ValueError('sx and sy must be integers')

    filter_length, frametimings = asc.ft_nblinks(exp_dir, stimulusnr)

    stimulus_type = parameters['stimulus_type']
    # Only these stimuli define the running/frozen frame counts used below.
    if stimulus_type not in ('FrozenNoise', 'checkerflickerplusmovie'):
        raise ValueError(f'Unsupported stimulus type: {stimulus_type}')

    runfr = parameters['RunningFrames']
    frofr = parameters['FrozenFrames']
    # The frozen part is generated with a second seed ('secondseed',
    # default -10000 as per StimulateOpenGL). It is not regenerated here
    # because STAs are only computed from the running part.

    if stimulus_type == 'checkerflickerplusmovie':
        mblinks = parameters['Nblinksmovie']
        # Retrieve the number of frames (files) from parameters['path']
        ipath = PureWindowsPath(parameters['path']).as_posix()
        repldict = iof.config('stimuli_path_replace')
        for needle, repl in repldict.items():
            ipath = ipath.replace(needle, repl)
        ipath = os.path.normpath(ipath)  # Windows compatibility
        moviefr = len([
            name for name in os.listdir(ipath)
            if os.path.isfile(os.path.join(ipath, name))
            and name.lower().endswith('.raw')
        ])
        noiselen = (runfr + frofr) * nblinks
        movielen = moviefr * mblinks
        triallen = noiselen + movielen

        ft_on, ft_off = asc.readframetimes(exp_dir, stimulusnr,
                                           returnoffsets=True)
        # Interleave onset and offset times.
        frametimings = np.empty(ft_on.shape[0] * 2, dtype=float)
        frametimings[::2] = ft_on
        frametimings[1::2] = ft_off

        import math
        ntrials = math.floor(frametimings.size / triallen)
        trials = np.zeros((ntrials, runfr + frofr + moviefr))
        for t in range(ntrials):
            frange = frametimings[t * triallen:(t + 1) * triallen]
            trials[t, :runfr + frofr] = frange[:noiselen][::nblinks]
            trials[t, runfr + frofr:] = frange[noiselen:][::mblinks]
        frametimings = trials.ravel()

        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        filter_length = int(np.round(.666 * refresh_rate / nblinks))

        # Add frozen movie to frozen noise (for masking)
        frofr += moviefr

    savefname = str(stimulusnr) + '_data'

    if clusterstoanalyze:
        clusters = clusters[:clusterstoanalyze, :]
        print('Analyzing first %s cells' % clusterstoanalyze)
        savefname += '_' + str(clusterstoanalyze) + 'cells'
    if frametimingsfraction:
        frametimingsindex = int(len(frametimings) * frametimingsfraction)
        frametimings = frametimings[:frametimingsindex]
        print('Analyzing first {}% of'
              ' the recording'.format(frametimingsfraction * 100))
        savefname += '_' + str(frametimingsfraction).replace('.', '') + 'fraction'
    frame_duration = np.average(np.ediff1d(frametimings))
    total_frames = frametimings.shape[0]

    all_spiketimes = []
    # Store spike triggered averages in a list containing correct shaped
    # arrays
    stas = []

    for i in range(len(clusters[:, 0])):
        spiketimes = asc.read_raster(exp_dir, stimulusnr,
                                     clusters[i, 0], clusters[i, 1])
        spikes = asc.binspikes(spiketimes, frametimings)
        all_spiketimes.append(spikes)
        stas.append(np.zeros((sx, sy, filter_length)))

    # Separate out the repeated parts
    all_spiketimes = np.array(all_spiketimes)
    mask = runfreezemask(total_frames, runfr, frofr, refresh_rate)
    repeated_spiketimes = all_spiketimes[:, ~mask]
    run_spiketimes = all_spiketimes[:, mask]

    # We need to cut down the total_frames by the same amount
    # as spiketimes
    total_run_frames = run_spiketimes.shape[1]
    # To be able to use the same code as checkerflicker analyzer,
    # convert to list again.
    run_spiketimes = list(run_spiketimes)

    # Empirically determined to be best for 32GB RAM
    desired_chunk_size = 21600000

    # Length of the chunks (specified in number of frames)
    chunklength = int(desired_chunk_size / (sx * sy))

    chunksize = chunklength * sx * sy
    nrofchunks = int(np.ceil(total_run_frames / chunklength))

    print(f'\nAnalyzing {stimname}.\nTotal chunks: {nrofchunks}')

    chunk_start = startime = datetime.datetime.now()
    timedeltas = []

    # Placeholder first row so np.vstack has something to stack onto;
    # it is removed after the loop.
    quals = np.zeros(len(stas))

    for i in range(nrofchunks):
        randnrs, seed = randpy.ranb(seed, chunksize)
        # Reshape and change 0's to -1's
        stimulus = np.reshape(randnrs, (sx, sy, chunklength),
                              order='F') * 2 - 1
        del randnrs

        # Range of indices we are interested in for the current chunk
        if (i + 1) * chunklength < total_run_frames:
            chunkind = slice(i * chunklength, (i + 1) * chunklength)
            chunkend = chunklength
        else:
            chunkind = slice(i * chunklength, None)
            chunkend = total_run_frames - i * chunklength

        # Slice each cluster's spikes for this chunk once, not per frame.
        chunk_spikes = [spk[chunkind] for spk in run_spiketimes]
        for k in range(filter_length, chunkend - filter_length + 1):
            stim_small = stimulus[:, :, k - filter_length + 1:k + 1][:, :, ::-1]
            for j, spikes in enumerate(chunk_spikes):
                if spikes[k] != 0:
                    stas[j] += spikes[k] * stim_small

        qual = np.ravel([asc.staquality(sta) for sta in stas])
        quals = np.vstack((quals, qual))

        # Draw progress bar
        width = 50  # Number of characters
        # Guard against division by zero when there is a single chunk.
        prog = i / (nrofchunks - 1) if nrofchunks > 1 else 1.0
        bar_complete = int(prog * width)
        bar_noncomplete = width - bar_complete
        timedeltas.append(msc.timediff(chunk_start))  # Calculate running avg
        avgelapsed = np.mean(timedeltas)
        elapsed = np.sum(timedeltas)
        # i + 1 chunks are finished, nrofchunks - i - 1 remain.
        etc = startime + elapsed + avgelapsed * (nrofchunks - i - 1)

        sys.stdout.flush()
        sys.stdout.write('\r{}{} |{:4.1f}% ETC: {}'.format(
            '█' * bar_complete, '-' * bar_noncomplete, prog * 100,
            etc.strftime("%a %X")))
        chunk_start = datetime.datetime.now()

    sys.stdout.write('\n')

    # Remove the first row which is the placeholder.
    quals = quals[1:, :]

    max_inds = []
    spikenrs = np.array([spikearr.sum() for spikearr in run_spiketimes])

    for i in range(clusters.shape[0]):
        # Clusters without spikes give 0/0 (NaN STA); silence that warning
        # regardless of the NumPy version's message wording.
        with np.errstate(divide='ignore', invalid='ignore'):
            stas[i] = stas[i] / spikenrs[i]
        # Find the pixel with largest absolute value
        max_i = np.squeeze(
            np.where(np.abs(stas[i]) == np.max(np.abs(stas[i]))))
        # If there are multiple pixels with largest value,
        # take the first one.
        if max_i.shape != (3, ):
            try:
                max_i = max_i[:, 0]
            # If max_i cannot be found just set it to zeros.
            except IndexError:
                max_i = np.array([0, 0, 0])

        max_inds.append(max_i)

    print(f'Completed. Total elapsed time: {msc.timediff(startime)}\n' +
          f'Finished on {datetime.datetime.now().strftime("%A %X")}')

    savepath = os.path.join(exp_dir, 'data_analysis', stimname)
    os.makedirs(savepath, exist_ok=True)
    savepath = os.path.join(savepath, savefname)

    datadict = {
        'clusters': clusters,
        'frametimings': frametimings,
        'mask': mask,
        'repeated_spiketimes': repeated_spiketimes,
        'run_spiketimes': run_spiketimes,
        'frame_duration': frame_duration,
        'max_inds': max_inds,
        'nblinks': nblinks,
        'stas': stas,
        'stx_h': stx_h,
        'stx_w': stx_w,
        'total_run_frames': total_run_frames,
        'sx': sx,
        'sy': sy,
        'filter_length': filter_length,
        'stimname': stimname,
        'exp_name': exp_name,
        'spikenrs': spikenrs,
        'clusterstoanalyze': clusterstoanalyze,
        'frametimingsfraction': frametimingsfraction,
        'cutoff': cutoff,
        'quals': quals,
        'nrofchunks': nrofchunks,
        'chunklength': chunklength,
    }
    np.savez(savepath, **datadict)

    # Duration analyzed per chunk in minutes. Assumes frametimings (and
    # hence frame_duration) are in seconds - TODO confirm. The previous
    # division by refresh_rate only gave minutes when it was 60 Hz.
    t = np.arange(nrofchunks) * chunklength * frame_duration / 60
    qmax = np.max(quals, axis=0)
    qualsn = quals / qmax[np.newaxis, :]

    ax = plt.subplot(111)
    ax.plot(t, qualsn, alpha=0.3)
    plt.ylabel('Z-score of center pixel (normalized)')
    plt.xlabel('Minutes of stimulus analyzed')
    plt.ylim([0, 1])
    plf.spineless(ax, 'tr')
    plt.title(f'Recording duration optimization\n{exp_name}\n {savefname}')
    plt.savefig(savepath + '.svg', format='svg')
    plt.close()
def plotcheckersvd(expname, stimnr, filename=None):
    """
    Plot the first two components of SVD analysis.

    For every cluster in the saved checkerflicker analysis the STA is
    cropped around its maximum (``msc.cut_around_center``) and decomposed
    with ``msc.svd``. The top row shows the frame at the peak time and the
    first two spatial components. The bottom panel shows the temporal
    profile of the center pixel together with the first two temporal
    components.

    Parameters
    ----------
    expname:
        Experiment name.
    stimnr:
        Stimulus number.
    filename:
        Optional data file name passed to ``iof.load``. When given, plots
        are saved in ``SVD_<filename without .npz>`` instead of ``SVD``.

    Plots are saved as
    ``<exp_dir>/data_analysis/<stimname>/<folder>/<clusterid>.svg``.
    Clusters whose STA cannot be cropped or decomposed are skipped.
    """
    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    if filename:
        filename = str(filename)

    exp_dir = iof.exp_dir_fixer(expname)
    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if filename:
        # Drop only a trailing '.npz'; str.strip('.npz') would also remove
        # any leading/trailing '.', 'n', 'p' or 'z' characters.
        label = filename[:-len('.npz')] if filename.endswith('.npz') else filename
        savefolder = 'SVD_' + label
    else:
        savefolder = 'SVD'

    data = iof.load(expname, stimnr, filename)

    stas = data['stas']
    max_inds = data['max_inds']
    clusters = data['clusters']
    stx_h = data['stx_h']
    frame_duration = data['frame_duration']
    stimname = data['stimname']
    exp_name = data['exp_name']

    clusterids = plf.clusters_to_ids(clusters)

    # Determine frame size so that the total frame covers
    # an area large enough i.e. 2*700um
    f_size = int(700 / (stx_h * px_size))

    cmap = iof.config('colormap')
    titles = ('center px', 'SVD spatial 1', 'SVD spatial 2')
    plotpath = os.path.join(exp_dir, 'data_analysis', stimname, savefolder)

    for i in range(clusters.shape[0]):
        sta = stas[i]
        max_i = max_inds[i]

        try:
            sta, max_i = msc.cut_around_center(sta, max_i, f_size=f_size)
        except ValueError:
            continue

        fit_frame = sta[:, :, max_i[2]]

        try:
            sp1, sp2, t1, t2, _, _ = msc.svd(sta)
        # If the STA is noisy (msc.cut_around_center produces an empty array)
        # SVD cannot be calculated, in this case we skip that cluster.
        except np.linalg.LinAlgError:
            continue

        plotthese = [fit_frame, sp1, sp2]

        plt.figure(dpi=200)
        plt.suptitle(f'{exp_name}\n{stimname}\n{clusterids[i]}')
        rows = 2
        cols = 3

        vmax = np.max(np.abs([sp1, sp2]))
        vmin = -vmax

        for j, (image, title) in enumerate(zip(plotthese, titles)):
            ax = plt.subplot(rows, cols, j + 1)
            im = plt.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
            ax.set_aspect('equal')
            plt.xticks([])
            plt.yticks([])
            # Color each panel frame like the default line colors (C0-C2)
            # used in the temporal plot below.
            for spine in ax.spines.values():
                spine.set_color('C{}'.format(j % 3))
                spine.set_linewidth(2)
            plt.title(title)
            plf.colorbar(im, ticks=[vmin, 0, vmax], format='%.2f')
            barsize = 100 / (stx_h * px_size)
            scalebar = AnchoredSizeBar(ax.transData,
                                       barsize, '100 µm',
                                       'lower left',
                                       pad=0,
                                       color='k',
                                       frameon=False,
                                       size_vertical=.3)
            ax.add_artist(scalebar)

        t = np.arange(sta.shape[-1]) * frame_duration * 1000
        plt.subplots_adjust(wspace=0.3, hspace=0)
        ax = plt.subplot(rows, 1, 2)
        plt.plot(t, sta[max_i[0], max_i[1], :], label='center px')
        plt.plot(t, t1, label='Temporal 1')
        plt.plot(t, t2, label='Temporal 2')
        plt.xlabel('Time[ms]')
        plf.spineless(ax, 'trlb')  # Turn off spines using custom function

        os.makedirs(plotpath, exist_ok=True)
        plt.savefig(os.path.join(plotpath, clusterids[i] + '.svg'), dpi=300)
        plt.close()
    print(f'Plotted checkerflicker SVD for {stimname}')
def plot_checker_stas(exp_name, stim_nr, filename=None):
    """
    Plot and save all STAs from checkerflicker analysis.

    One figure per cluster is drawn with one panel per STA frame. The
    plots are saved in a new folder called STAs (or
    ``STAs_<filename without .npz>`` when ``filename`` is given) under the
    data analysis path of the stimulus, named ``<channel><cluster>.png``.

    By default the data file that ``iof.load`` finds for the stimulus is
    used (presumably ``<exp_dir>/data_analysis/<stim_nr>_*/<stim_nr>_data.npz``;
    verify against iof.load). If a different file is to be used, filename
    should be supplied.
    """
    from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

    exp_dir = iof.exp_dir_fixer(exp_name)
    stim_nr = str(stim_nr)
    if filename:
        filename = str(filename)

    _, metadata = asc.read_spikesheet(exp_dir)
    px_size = metadata['pixel_size(um)']

    if not filename:
        savefolder = 'STAs'
        label = ''
    else:
        # Drop only a trailing '.npz'; str.strip('.npz') would also remove
        # any leading/trailing '.', 'n', 'p' or 'z' characters.
        label = filename[:-len('.npz')] if filename.endswith('.npz') else filename
        savefolder = 'STAs_' + label

    data = iof.load(exp_name, stim_nr, fname=filename)

    clusters = data['clusters']
    stas = data['stas']
    filter_length = data['filter_length']
    stx_h = data['stx_h']
    exp_name = data['exp_name']
    stimname = data['stimname']

    # Loop invariants: the panel grid depends only on the number of frames.
    subplot_arr = plf.numsubplots(filter_length)
    cmap = iof.config('colormap')

    for j in range(clusters.shape[0]):
        a = stas[j]
        # Symmetric color limits around zero.
        sta_max = np.max(np.abs([np.max(a), np.min(a)]))
        sta_min = -sta_max
        plt.figure(dpi=250)
        for i in range(filter_length):
            ax = plt.subplot(subplot_arr[0], subplot_arr[1], i + 1)
            im = ax.imshow(a[:, :, i], vmin=sta_min, vmax=sta_max, cmap=cmap)
            ax.set_aspect('equal')
            plt.axis('off')
            if i == 0:
                # Scale bar spanning 10 stixels (data units of the image).
                scalebar = AnchoredSizeBar(ax.transData,
                                           10, '{} µm'.format(10 * stx_h * px_size),
                                           'lower left',
                                           pad=0,
                                           color='k',
                                           frameon=False,
                                           size_vertical=1)
                ax.add_artist(scalebar)
            if i == filter_length - 1:
                plf.colorbar(im, ticks=[sta_min, 0, sta_max], format='%.2f')
        plt.suptitle('{}\n{}\n'
                     '{:0>3}{:0>2} Rating: {}'.format(exp_name,
                                                      stimname + label,
                                                      clusters[j][0],
                                                      clusters[j][1],
                                                      clusters[j][2]))

        clusterid = '{:0>3}{:0>2}'.format(clusters[j][0], clusters[j][1])
        savepath = os.path.join(exp_dir, 'data_analysis',
                                stimname, savefolder, clusterid)

        os.makedirs(os.path.split(savepath)[0], exist_ok=True)

        plt.savefig(savepath + '.png', bbox_inches='tight')
        plt.close()
    print(f'Plotted checkerflicker STA for {stimname}')
def multistabrowser(stas, frame_duration=None, cmap=None, centerzero=True):
    """
    Returns an interactive plot to browse multiple spatiotemporal
    STAs at the same time. Requires an interactive matplotlib backend.

    Parameters
    ----------
    stas:
        Numpy array (or list of arrays) containing STAs. First dimension
        should index individual cells, last dimension should index time.
    frame_duration:
        Time between each frame. (optional) If given, the figure title
        shows the frame time in ms.
    cmap:
        Colormap to use. Defaults to the colormap in the config file.
    centerzero:
        Whether to center the colormap around zero for diverging
        colormaps. Color limits are shared by all cells.

    Example
    -------
    >>> print(stas.shape) # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = multistabrowser(stas, frame_duration=1/60)

    Notes
    -----
    Cells are laid out row by row in the subplot grid; grid positions
    beyond the number of cells are left empty.

    When calling the function, the slider is returned to prevent
    the reference to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    interactive_backends = ['Qt', 'Tk']
    backend = mpl.get_backend()
    if not backend[:2] in interactive_backends:
        raise ValueError('Switch to an interactive backend (e.g. Qt) to see'
                         ' the animation.')

    if isinstance(stas, list):
        stas = np.array(stas)

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = absmax(stas)
        vmin = absmin(stas)
    else:
        vmax, vmin = stas.max(), stas.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin)

    nrcells = stas.shape[0]
    rows, cols = plf.numsubplots(nrcells)
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)
    # Row-major order: cell index = row * cols + col.
    all_axes = axes.ravel()
    cell_axes = all_axes[:nrcells]
    for unused_ax in all_axes[nrcells:]:
        unused_ax.set_axis_off()

    # Start at frame 5, or at the last frame for shorter STAs.
    initial_frame = min(5, stas.shape[-1] - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, stas.shape[-1] - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        frame = int(frame)
        for cell, ax in enumerate(cell_axes):
            im = ax.get_images()[0]
            im.set_data(stas[cell, ..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    for cell, ax in enumerate(cell_axes):
        ax.imshow(stas[cell, ..., initial_frame], **imshowkwargs)
        ax.set_axis_off()

    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
def stabrowser(sta, frame_duration=None, cmap=None, centerzero=True, **kwargs):
    """
    Returns an interactive plot to browse a spatiotemporal
    STA. Requires an interactive matplotlib backend.

    Parameters
    ----------
    sta:
        Numpy array containing the STA. Last dimension should index time.
    frame_duration:
        Time between each frame. (optional) If given, the figure title
        shows the frame time in ms.
    cmap:
        Colormap to use. Defaults to the colormap in the config file.
    centerzero:
        Whether to center the colormap around zero for diverging
        colormaps.
    **kwargs:
        Passed on to ``imshow`` (must not repeat cmap/vmin/vmax).

    Example
    -------
    >>> print(sta.shape) # (xpixels, ypixels, time)
    (75, 100, 40)
    >>> fig, slider = stabrowser(sta, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent
    the reference to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    check_interactive_backend()

    if cmap is None:
        cmap = iof.config('colormap')
    if centerzero:
        vmax = asc.absmax(sta)
        vmin = asc.absmin(sta)
    else:
        vmax, vmin = sta.max(), sta.min()

    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin, **kwargs)

    fig = plt.figure()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

    # Start at frame 5, or at the last frame for STAs with fewer frames
    # (a fixed index of 5 raised IndexError for short STAs).
    initial_frame = min(5, sta.shape[-1] - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it set to a variable and is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, sta.shape[-1] - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    def update(frame):
        frame = int(frame)
        im = ax.get_images()[0]
        im.set_data(sta[..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    ax.imshow(sta[..., initial_frame], **imshowkwargs)
    ax.set_axis_off()
    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
i = 63 #%% sta = stas[i] max_i = max_inds[i] fit_frame = sta[:, :, max_i[2]] stac, max_i = msc.cut_around_center(sta, max_i, f_size + 2) sp1, sp2, t1, t2, _, _ = msc.svd(stac) #%% plt.figure(figsize=(12, 10)) vmax = np.max(np.abs([sp1, sp2])) vmin = -vmax plt.subplot(131) plt.imshow(sp1, cmap=iof.config('colormap'), vmin=vmin, vmax=vmax) plt.subplot(132) plt.imshow(sp2, cmap=iof.config('colormap'), vmin=vmin, vmax=vmax) plt.subplot(133) im = plt.imshow(fit_frame, cmap=iof.config('colormap'), vmin=vmin, vmax=vmax) plf.colorbar(im, size='2%', ticks=[vmin, vmax], format='%.2f') plt.show() sp1c, maxic = msc.cut_around_center(sp1, max_i, f_size) sp2c, maxic = msc.cut_around_center(sp2, max_i, f_size) fit_framec, maxic = msc.cut_around_center(fit_frame, max_i, f_size) plt.figure(figsize=(12, 10)) vmax = np.max(np.abs([sp1c, sp2c])) vmin = -vmax plt.subplot(131)