def load_raw_data(self):
    # TODO: account for these temporary hardcoded params
    n_channels = self.num_channels
    dtype = np.int16
    sample_rate = self.params.fs
    raw_data = get_ephys_reader(
        self.data_path, sample_rate=sample_rate, dtype=dtype, n_channels=n_channels)
    self.raw_data = raw_data
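# Illustrative usage sketch (hypothetical host object, not part of this file):
# after `load_raw_data`, `self.raw_data` holds an array-like reader of shape
# (n_samples, n_channels), the interface relied upon by `run()` below.
#
#     loader.load_raw_data()
#     n_samples, n_channels = loader.raw_data.shape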
def create_trace_gui(obj, **kwargs):
    """Create the Trace GUI.

    Parameters
    ----------

    obj : str or Path
        Path to the raw data file.
    sample_rate : float
        The data sampling rate, in Hz.
    n_channels_dat : int
        The number of columns in the raw data file.
    dtype : str
        The NumPy data type of the raw binary file.
    offset : int
        The header offset in bytes.

    """
    gui_name = 'TraceGUI'

    # Support passing a params.py file.
    if str(obj).endswith('.py'):
        params = get_template_params(str(obj))
        return create_trace_gui(next(iter(params.pop('dat_path'))), **params)
    obj = Path(obj)  # ensure a Path so that .resolve() below is valid

    kwargs = {
        k: v for k, v in kwargs.items()
        if k in ('sample_rate', 'n_channels_dat', 'dtype', 'offset')}
    traces = get_ephys_reader(obj, **kwargs)

    create_app()
    gui = GUI(name=gui_name, subtitle=str(obj.resolve()), enable_threading=False)
    gui.set_default_actions()

    def _get_traces(interval):
        return Bunch(data=select_traces(traces, interval, sample_rate=traces.sample_rate))

    # TODO: load channel information
    view = TraceView(
        traces=_get_traces,
        n_channels=traces.n_channels,
        sample_rate=traces.sample_rate,
        duration=traces.duration,
        enable_threading=False,
    )
    view.attach(gui)
    return gui
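if __name__ == '__main__':
    # Illustrative launch sketch (an assumption, not part of the original
    # module): open the Trace GUI on a hypothetical raw binary file.
    # `run_app` is assumed importable from the same GUI toolkit as
    # `create_app` above; the file name and parameters are placeholders.
    gui = create_trace_gui(
        'rawdata.bin', sample_rate=30000., n_channels_dat=64,
        dtype='int16', offset=0)
    gui.show()
    run_app()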
def run(dat_path: str = None, dir_path: Path = None, output_dir: Path = None, probe=None,
        params=None, stop_after=None, clear_context=False, **kwargs):
    """Launch KiloSort 2.

    probe has the following attributes:
    - xc
    - yc
    - kcoords
    - Nchan

    """
    # Get or create the probe object.
    if isinstance(probe, (str, Path)):
        probe = load_probe(probe)

    raw_data = get_ephys_reader(dat_path, **kwargs)
    assert raw_data.ndim == 2
    # The initial raw data must be in C order; it will be converted to Fortran order
    # in the proc file step, so as to use the existing CUDA kernels from MATLAB.
    assert raw_data.shape[0] > raw_data.shape[1]  # nsamples > nchannels
    n_samples, n_channels = raw_data.shape
    logger.info("Loaded raw data with %d channels, %d samples.", n_channels, n_samples)

    # TODO: design - let's pass in all of the config already parsed and ready into this
    #  function - run should do 1 thing only: run the steps of the algorithm.

    # Get probe.
    probe = probe or default_probe(raw_data)
    assert probe

    # Get params.
    if not isinstance(params, BaseModel):
        params = KilosortParams(**params or {})
    assert params

    # Dir path.
    dir_path = dir_path or Path(dat_path).parent
    assert dir_path, "Please provide a dir_path"
    dir_path.mkdir(exist_ok=True, parents=True)
    assert dir_path.exists()

    # Create the context.
    ctx_path = dir_path / ".kilosort" / raw_data.name
    if clear_context:
        logger.info(f"Clearing context at {ctx_path} ...")
        shutil.rmtree(ctx_path, ignore_errors=True)
    ctx = Context(ctx_path)
    ctx.params = params
    ctx.probe = probe
    ctx.raw_data = raw_data

    # Load the intermediate results to avoid recomputing things.
    ctx.load()
    # TODO: unclear - what if we have changed something, e.g. a parameter? Shouldn't
    #  we make the path dependent on at least the hash of the params?
    #  We should also be able to turn this off for easy testing / experimentation.
    ir = ctx.intermediate

    ir.Nbatch = get_Nbatch(raw_data, params)

    # -------------------------------------------------------------------------
    # Find good channels.
    # NOTE: we use C order from loading up to the creation of the proc file, which is
    # in Fortran order.
    if params.minfr_goodchannels > 0:  # discard channels that have very few spikes
        if "igood" not in ir:
            # Determine bad channels.
            with ctx.time("good_channels"):
                ir.igood = get_good_channels(raw_data=raw_data, probe=probe, params=params)
            # Cache the result.
            ctx.write(igood=ir.igood)
        if stop_after == "good_channels":
            return ctx

        # It is enough to remove bad channels from the channel map, which treats them
        # as if they were dead.
        ir.igood = ir.igood.ravel()
        probe.chanMap = probe.chanMap[ir.igood]
        probe.xc = probe.xc[ir.igood]  # removes coordinates of bad channels
        probe.yc = probe.yc[ir.igood]
        probe.kcoords = probe.kcoords[ir.igood]
    probe.Nchan = len(probe.chanMap)  # total number of good channels that we will spike sort
    assert probe.Nchan > 0

    # Upper bound on the number of templates we can have.
    params.Nfilt = params.nfilt_factor * probe.Nchan

    # -------------------------------------------------------------------------
    # Find the whitening matrix.
    if "Wrot" not in ir:
        # Outputs a rotation matrix (Nchan by Nchan) which whitens the zero-timelag
        # covariance of the data.
        with ctx.time("whitening_matrix"):
            ir.Wrot = get_whitening_matrix(raw_data=raw_data, probe=probe, params=params)
        # Cache the result.
        ctx.write(Wrot=ir.Wrot)
    if stop_after == "whitening_matrix":
        return ctx

    # -------------------------------------------------------------------------
    # Preprocess data to create proc.dat.
    ir.proc_path = ctx.path("proc", ".dat")
    if not ir.proc_path.exists():
        # Do not preprocess again if the proc.dat file already exists.
        with ctx.time("preprocess"):
            preprocess(ctx)
    if stop_after == "preprocess":
        return ctx

    # Open the proc file.
    # NOTE: from here on we are always in Fortran order.
    assert ir.proc_path.exists()
    ir.proc = np.memmap(ir.proc_path, dtype=raw_data.dtype, mode="r", order="F")

    # -------------------------------------------------------------------------
    # Time-reordering as a function of drift.
    #
    # This function saves:
    #
    #       iorig, ccb0, ccbsort
    #
    if "iorig" not in ir:
        with ctx.time("reorder"):
            out = clusterSingleBatches(ctx)
        ctx.save(**out)
    if stop_after == "reorder":
        return ctx

    # -------------------------------------------------------------------------
    # Main tracking and template matching algorithm.
    #
    # This function uses:
    #
    #       procfile
    #       iorig
    #
    # This function saves:
    #
    #       wPCA, wTEMP
    #       st3, simScore,
    #       cProj, cProjPC,
    #       iNeigh, iNeighPC,
    #       WA, UA, W, U, dWU, mu,
    #       W_a, W_b, U_a, U_b
    #
    if "st3" not in ir:
        with ctx.time("learn"):
            out = learnAndSolve8b(ctx)
        logger.info("%d spikes.", ir.st3.shape[0])
        ctx.save(**out)
    if stop_after == "learn":
        return ctx

    # Special care for cProj and cProjPC which are memmapped .dat files.
    ir.cProj = memmap_large_array(ctx.path("fW", ext=".dat")).T
    ir.cProjPC = memmap_large_array(ctx.path("fWpc", ext=".dat")).T  # transpose

    # -------------------------------------------------------------------------
    # Final merges.
    #
    # This function uses:
    #
    #       st3, simScore
    #
    # This function saves:
    #
    #       st3_m,
    #       R_CCG, Q_CCG, K_CCG [optional]
    #
    if "st3_m" not in ir:
        with ctx.time("merge"):
            out = find_merges(ctx, True)
        ctx.save(**out)
    if stop_after == "merge":
        return ctx

    # -------------------------------------------------------------------------
    # Final splits.
    #
    # This function uses:
    #
    #       st3_m
    #       W, dWU, cProjPC,
    #       iNeigh, simScore
    #       wPCA
    #
    # This function saves:
    #
    #       st3_s
    #       W_s, U_s, mu_s, simScore_s
    #       iNeigh_s, iNeighPC_s,
    #       Wphy, iList, isplit
    #
    if "st3_s1" not in ir:
        # Final splits by SVD.
        with ctx.time("split_1"):
            out = splitAllClusters(ctx, True)
        # Use a different name for each of the two splitting steps.
        out["st3_s1"] = out.pop("st3_s")
        ctx.save(**out)
    if stop_after == "split_1":
        return ctx

    if "st3_s0" not in ir:
        # Final splits by amplitudes.
        with ctx.time("split_2"):
            out = splitAllClusters(ctx, False)
        out["st3_s0"] = out.pop("st3_s")
        ctx.save(**out)
    if stop_after == "split_2":
        return ctx

    # -------------------------------------------------------------------------
    # Decide on cutoff.
    #
    # This function uses:
    #
    #       st3_s
    #       dWU, cProj, cProjPC
    #       wPCA
    #
    # This function saves:
    #
    #       st3_c, spikes_to_remove,
    #       est_contam_rate, Ths, good
    #
    if "st3_c" not in ir:
        with ctx.time("cutoff"):
            out = set_cutoff(ctx)
        ctx.save(**out)
    if stop_after == "cutoff":
        return ctx

    logger.info("%d spikes after cutoff.", ir.st3_c.shape[0])
    logger.info("Found %d good units.", np.sum(ir.good > 0))

    # Write to Phy.
    logger.info("Saving results to phy.")
    output_dir = output_dir or f"{dir_path}/output"
    with ctx.time("output"):
        rezToPhy(ctx, dat_path=dat_path, output_dir=output_dir)

    # Show timing information.
    ctx.show_timer()
    ctx.write(timer=ctx.timer)

    return ctx
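if __name__ == '__main__':
    # Illustrative pipeline sketch (an assumption, not part of the original
    # module): sort a hypothetical raw binary file with a probe file. Extra
    # keyword arguments are forwarded to `get_ephys_reader`, matching the
    # parameters used in `load_raw_data` above; file names and values are
    # placeholders. Because intermediate results are cached in the context,
    # the second call resumes after the preprocessing step instead of
    # recomputing it.
    ctx = run(
        dat_path='rawdata.bin',
        probe='neuropixels.prb',
        params={'minfr_goodchannels': 0.1},
        stop_after='preprocess',
        sample_rate=30000., dtype='int16', n_channels=385)
    ctx = run(
        dat_path='rawdata.bin',
        probe='neuropixels.prb',
        params={'minfr_goodchannels': 0.1},
        sample_rate=30000., dtype='int16', n_channels=385)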
import importlib
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from phylib.io.traces import get_ephys_reader  # assumed source of get_ephys_reader

from hyb_clu import hyb_clu
import split

# Load raw traces for calculation of SNR.
trace_args = ['sample_rate', 'n_channels_dat', 'dtype', 'offset']
trace_vals = [[25000], [64], [np.int16], [0]]
kwargs = {trace_args[x]: trace_vals[x][0] for x in range(len(trace_args))}
data_dir = os.getcwd() + r'\ConcatenatedData_Probe1.GT.bin'
traces = get_ephys_reader(data_dir, **kwargs)
pbounds = slice(traces.part_bounds[0], traces.part_bounds[1])
traces = traces._get_part(pbounds, 0)

# Noise estimate for spike-sorting data (median absolute deviation, scaled
# to approximate the standard deviation of Gaussian noise).
noise_est = lambda data: np.median(np.abs(data) / 0.6745)

# Replace stdout with a text file.
timestr = time.strftime("splitter_%Y%m%d-%H%M%S")
orig_stdout = sys.stdout
sys.stdout = open(timestr + '.txt', 'w')

# For development pipeline convenience.
importlib.reload(split)

# First, data directories, ground truth (GT) clusters, artificially added clusters,
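# Added illustration (not in the original script): sanity check of `noise_est`.
# For zero-mean Gaussian noise with standard deviation sigma,
# median(|x|) ~= 0.6745 * sigma (0.6745 is the 75th percentile of the standard
# normal distribution), so median(|x|) / 0.6745 recovers sigma while staying
# robust to spike-like outliers:
#
#     rng = np.random.default_rng(0)
#     print(noise_est(rng.normal(0., 10., size=100_000)))  # ~10.0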