def gen_synth_datasets(datasets, *, outdir, samplerate=32000):
    """Generate synthetic recording datasets and write them to *outdir*.

    For each dataset spec in *datasets* (a dict with keys such as 'name',
    'duration', 'n_exc', 'n_inh', firing-rate and template parameters), this
    generates spike trains, synthesizes a recording, and writes the recording
    plus ground-truth firings under ``outdir/<name>``.

    Parameters
    ----------
    datasets : list of dict
        Dataset specifications. Each dict is mutated in place to default
        'seed' to 0 when absent (matches original behavior).
    outdir : str
        Output directory; created (including parents) if missing.
    samplerate : int, optional
        Sampling rate in Hz (default 32000).
    """
    # exist_ok avoids the check-then-create race of the original exists()/mkdir()
    os.makedirs(outdir, exist_ok=True)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        # idiomatic replacement for: if 'seed' not in ds.keys(): ds['seed'] = 0
        ds.setdefault('seed', 0)
        spiketrains = gen_spiketrains(
            duration=ds['duration'],
            n_exc=ds['n_exc'],
            n_inh=ds['n_inh'],
            f_exc=ds['f_exc'],
            f_inh=ds['f_inh'],
            min_rate=ds['min_rate'],
            st_exc=ds['st_exc'],
            st_inh=ds['st_inh'],
            seed=ds['seed']
        )
        OX = NeoSpikeTrainsOutputExtractor(
            spiketrains=spiketrains, samplerate=samplerate)
        X, geom = gen_recording(
            templates=ds['templates'],
            output_extractor=OX,
            noise_level=ds['noise_level'],
            samplerate=samplerate,
            duration=ds['duration']
        )
        IX = si.NumpyRecordingExtractor(
            timeseries=X, samplerate=samplerate, geom=geom)
        ds_dir = os.path.join(outdir, ds_name)
        SFMdaRecordingExtractor.write_recording(IX, ds_dir)
        SFMdaSortingExtractor.write_sorting(
            OX, os.path.join(ds_dir, 'firings_true.mda'))
    print('Done.')
def _generate_toy_recordings():
    """Generate toy recordings with K = 5, 10, 15, 20 units.

    Writes each recording (and its ground-truth firings) under
    ``toy_recordings/example_K<K>`` unless it already exists, and returns a
    list of descriptor dicts (name, study, absolute directory, description)
    for all four recordings.

    BUG FIX: the original printed 'Generating toy recording' only when an
    existing recording was about to be *deleted*, and printed 'Recording
    already exists' even when the directory was missing and about to be
    generated. The branches below make the messages match what happens.
    """
    if not os.path.exists('toy_recordings'):
        os.mkdir('toy_recordings')
    replace_recordings = False  # set True to regenerate existing recordings
    ret = []
    for K in [5, 10, 15, 20]:
        recpath = 'toy_recordings/example_K{}'.format(K)
        if os.path.exists(recpath):
            if replace_recordings:
                print('Removing existing recording: {}'.format(recpath))
                shutil.rmtree(recpath)
            else:
                print('Recording already exists: {}'.format(recpath))
        if not os.path.exists(recpath):
            print('Generating toy recording: {}'.format(recpath))
            rx, sx_true = example_datasets.toy_example1(
                duration=60, num_channels=4, samplerate=30000, K=K)
            SFMdaRecordingExtractor.write_recording(
                recording=rx, save_path=recpath)
            SFMdaSortingExtractor.write_sorting(
                sorting=sx_true, save_path=recpath + '/firings_true.mda')
        ret.append(
            dict(
                name='example_K{}'.format(K),
                study='toy_study',
                directory=os.path.abspath(recpath),
                description='A toy recording with K={} units'.format(K)))
    return ret
def jrclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=None,  # Optional dict written into params.json ('scale_factor' is honored)
        **kwargs):
    """Run the JRCLUST sorter on *recording* and return a sorting extractor.

    Writes the recording as raw.mda/geom.csv/params.json into a dataset
    directory under *tmpdir*, builds an argfile.txt from **kwargs (plus
    bitScaling/sampleRate), runs JRCLUST via a generated MATLAB script, and
    parses the resulting firings.mda.

    Raises
    ------
    Exception
        If JRCLUST cannot be installed, exits non-zero, or produces no
        firings.mda output file.
    """
    if params is None:  # avoid a shared mutable default argument
        params = {}
    jrclust_path = os.environ.get('JRCLUST_PATH_DEV', None)
    if jrclust_path:
        print('Using jrclust from JRCLUST_PATH_DEV directory: {}'.format(
            jrclust_path))
    else:
        try:
            print('Auto-installing jrclust.')
            jrclust_path = install_jrclust(
                repo='https://github.com/JaneliaSciComp/JRCLUST.git',
                commit='3d2e75c0041dca2a9f273598750c6a14dbc4c1b8')
        except Exception:  # narrowed from bare except (no longer traps KeyboardInterrupt)
            traceback.print_exc()
            raise Exception(
                'Problem installing jrclust. You can set the JRCLUST_PATH_DEV to force to use a particular path.'
            )
    print('Using jrclust from: {}'.format(jrclust_path))
    dataset_dir = os.path.join(tmpdir, 'jrclust_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)
    samplerate = recording.get_sampling_frequency()
    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))
    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    if 'scale_factor' in params:
        txt += 'bitScaling={}\n'.format(params["scale_factor"])
    txt += 'sampleRate={}\n'.format(samplerate)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running jrclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
addpath('{jrclust_path}', '{source_path}', '{source_path}/mdaio');
try
    p_jrclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
catch
    fprintf('----------------------------------------');
    fprintf(lasterr());
    quit(1);
end
quit(0);
'''
    cmd = cmd.format(jrclust_path=jrclust_path, tmpdir=tmpdir,
                     dataset_dir=dataset_dir, source_path=source_path)
    matlab_cmd = mlpr.ShellScript(
        cmd, script_path=tmpdir + '/run_jrclust.m', keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
#!/bin/bash
cd {tmpdir}
matlab -nosplash -nodisplay -r run_jrclust
'''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(
        shell_cmd, script_path=tmpdir + '/run_jrclust.sh', keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_jrclust.sh')
    time_ = time.time()
    shell_cmd.start()
    retcode = shell_cmd.wait()
    # BUG FIX: elapsed time is now - start; the original printed
    # time_ - time.time(), i.e. a negative runtime.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - time_))
    if retcode != 0:
        raise Exception('jrclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)
    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def kilosort2_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
        adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
        detect_threshold=6,  # Threshold for detection
        merge_thresh=.98,  # Cluster merging threhold 0..1
        freq_min=150,  # Lower frequency limit for band-pass filter
        freq_max=6000,  # Upper frequency limit for band-pass filter
        pc_per_chan=3,  # number of PC per chan
        minFR=1 / 50):  # minimum firing rate used by KiloSort2
    """Run KiloSort2 on *recording* and return a sorting extractor.

    Writes raw.mda/geom.csv/params.json into a dataset directory under
    *tmpdir*, builds an argfile.txt of sorter parameters, runs KiloSort2 via
    a generated MATLAB script, and parses the resulting firings.mda.

    Raises
    ------
    Exception
        If KiloSort2 cannot be installed, exits non-zero, or produces no
        firings.mda output file.
    """
    # TODO: do not require ks2 to depend on irc -- rather, put all necessary
    # .m code in the spikeforest repo (dead ironclust-install code removed).
    kilosort2_path = os.environ.get('KILOSORT2_PATH_DEV', None)
    if kilosort2_path:
        print('Using kilosort2 from KILOSORT2_PATH_DEV directory: {}'.format(
            kilosort2_path))
    else:
        try:
            print('Auto-installing kilosort2.')
            kilosort2_path = KiloSort2.install()
        except Exception:  # narrowed from bare except (no longer traps KeyboardInterrupt)
            traceback.print_exc()
            raise Exception(
                'Problem installing kilosort2. You can set the KILOSORT2_PATH_DEV to force to use a particular path.'
            )
    print('Using kilosort2 from: {}'.format(kilosort2_path))
    source_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_dir = tmpdir + '/kilosort2_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, _preserve_dtype=True)
    samplerate = recording.get_sampling_frequency()
    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))
    print('Creating argfile.txt file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'minFR={}\n'.format(minFR)
    _write_text_file(dataset_dir + '/argfile.txt', txt)
    print('Running Kilosort2 in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
addpath('{source_dir}');
addpath('{source_dir}/mdaio')
try
    p_kilosort2('{ksort}', '{tmpdir}', '{raw}', '{geom}', '{firings}', '{arg}');
catch
    quit(1);
end
quit(0);
'''
    cmd = cmd.format(source_dir=source_dir, ksort=kilosort2_path,
                     tmpdir=tmpdir, raw=dataset_dir + '/raw.mda',
                     geom=dataset_dir + '/geom.csv',
                     firings=tmpdir + '/firings.mda',
                     arg=dataset_dir + '/argfile.txt')
    matlab_cmd = mlpr.ShellScript(
        cmd, script_path=tmpdir + '/run_kilosort2.m', keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
#!/bin/bash
cd {tmpdir}
echo '=====================' `date` '====================='
matlab -nosplash -nodisplay -r run_kilosort2
'''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(
        shell_cmd, script_path=tmpdir + '/run_kilosort2.sh', keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_kilosort2.sh')
    shell_cmd.start()
    retcode = shell_cmd.wait()
    if retcode != 0:
        raise Exception('kilosort2 returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)
    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def ironclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=None,  # Optional dict written into params.json ('scale_factor' is honored)
        ironclust_path,  # Path to an IronClust checkout
        **kwargs):
    """Run IronClust on *recording* and return a sorting extractor.

    Writes raw.mda/geom.csv/params.json into a dataset directory under
    *tmpdir*, builds an argfile.txt from **kwargs (plus samplerate and the
    optional scale_factor), runs IronClust via a generated MATLAB script,
    and parses the resulting firings.mda.

    Raises
    ------
    Exception
        If IronClust exits non-zero or produces no firings.mda output file.
    """
    if params is None:  # avoid a shared mutable default argument
        params = {}
    source_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_dir = tmpdir + '/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)
    samplerate = recording.get_sampling_frequency()
    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))
    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    txt += 'samplerate={}\n'.format(samplerate)
    if 'scale_factor' in params:
        txt += 'scale_factor={}\n'.format(params["scale_factor"])
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # new method
    print('Running ironclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
addpath('{source_dir}');
addpath('{ironclust_path}', '{ironclust_path}/matlab', '{ironclust_path}/matlab/mdaio');
try
    p_ironclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '', '', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
catch
    fprintf('----------------------------------------');
    fprintf(lasterr());
    quit(1);
end
quit(0);
'''
    cmd = cmd.format(ironclust_path=ironclust_path, tmpdir=tmpdir,
                     dataset_dir=dataset_dir, source_dir=source_dir)
    matlab_cmd = mlpr.ShellScript(
        cmd, script_path=tmpdir + '/run_ironclust.m', keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
#!/bin/bash
cd {tmpdir}
matlab -nosplash -nodisplay -r run_ironclust
'''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(
        shell_cmd, script_path=tmpdir + '/run_ironclust.sh', keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_ironclust.sh')
    shell_cmd.start()
    retcode = shell_cmd.wait()
    if retcode != 0:
        raise Exception('ironclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)
    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def waveclus_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=None,  # Optional dict written into params.json
        **kwargs):
    """Run wave_clus on *recording* and return a sorting extractor.

    Writes raw.mda/geom.csv/params.json into a dataset directory under
    *tmpdir*, runs wave_clus via a generated MATLAB script, and parses the
    resulting firings.mda. Note: **kwargs are accepted but not forwarded to
    the sorter (matches original behavior).

    Raises
    ------
    Exception
        If wave_clus cannot be installed, exits non-zero, or produces no
        firings.mda output file.
    """
    if params is None:  # avoid a shared mutable default argument
        params = {}
    waveclus_path = os.environ.get('WAVECLUS_PATH_DEV', None)
    if waveclus_path:
        print('Using waveclus from WAVECLUS_PATH_DEV directory: {}'.format(
            waveclus_path))
    else:
        try:
            print('Auto-installing waveclus.')
            waveclus_path = install_waveclus(
                repo='https://github.com/csn-le/wave_clus.git',
                commit='248d15c7eaa2b45b15e4488dfb9b09bfe39f5341')
        except Exception:  # narrowed from bare except (no longer traps KeyboardInterrupt)
            traceback.print_exc()
            raise Exception(
                'Problem installing waveclus. You can set the WAVECLUS_PATH_DEV to force to use a particular path.'
            )
    print('Using waveclus from: {}'.format(waveclus_path))
    dataset_dir = os.path.join(tmpdir, 'waveclus_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)
    samplerate = recording.get_sampling_frequency()
    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running waveclus in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
addpath(genpath('{waveclus_path}'), '{source_path}', '{source_path}/mdaio');
try
    p_waveclus('{tmpdir}', '{dataset_dir}/raw.mda', '{tmpdir}/firings.mda', {samplerate});
catch
    fprintf('----------------------------------------');
    fprintf(lasterr());
    quit(1);
end
quit(0);
'''
    cmd = cmd.format(waveclus_path=waveclus_path, tmpdir=tmpdir,
                     dataset_dir=dataset_dir, source_path=source_path,
                     samplerate=samplerate)
    matlab_cmd = mlpr.ShellScript(
        cmd, script_path=tmpdir + '/run_waveclus.m', keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
#!/bin/bash
cd {tmpdir}
matlab -nosplash -nodisplay -r run_waveclus
'''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(
        shell_cmd, script_path=tmpdir + '/run_waveclus.sh', keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_waveclus.sh')
    time_ = time.time()
    shell_cmd.start()
    retcode = shell_cmd.wait()
    # BUG FIX: elapsed time is now - start; the original printed
    # time_ - time.time(), i.e. a negative runtime.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - time_))
    if retcode != 0:
        raise Exception('waveclus returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)
    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
import os
import shutil

from spikeforest import example_datasets
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor

# Build a small synthetic dataset to exercise the MDA writers.
recording, sorting_true = example_datasets.toy_example1()

recdir = 'toy_example1'

# Start from a clean slate: drop any leftover output from a previous run.
if os.path.exists(recdir):
    shutil.rmtree(recdir)

print('Preparing toy recording...')

# Persist the recording and its ground-truth sorting side by side.
SFMdaRecordingExtractor.write_recording(
    recording=recording,
    save_path=recdir)
SFMdaSortingExtractor.write_sorting(
    sorting=sorting_true,
    save_path=recdir + '/firings_true.mda')
# from mountaintools import client as ca
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor, example_datasets
import os
from spikeforestsorters import YASS
from spikeforest import spikewidgets as sw

# %%SortingComparisonyass_test1/recording/raw.mda
tmpdir = 'yass_test1'
if not os.path.isdir(tmpdir):
    os.mkdir(tmpdir)

# Fetch the YASS example dataset (recording + ground-truth sorting).
rx, sx = example_datasets.yass_example(set_id=1)

# %%
recording_path = tmpdir + '/recording'
firings_true = tmpdir + '/recording/firings_true.mda'
SFMdaRecordingExtractor.write_recording(recording=rx, save_path=recording_path)
SFMdaSortingExtractor.write_sorting(sorting=sx, save_path=firings_true)

# Run the YASS sorter on the written recording.
YASS.execute(
    recording_dir=tmpdir + '/recording',
    firings_out=tmpdir + '/firings_out.mda',
    detect_sign=-1,
    adjacency_radius=50,
    _container=None,
    _force_run=True,
    _keep_temp_files=True)

firings_out = tmpdir + '/firings_out.mda'
assert os.path.exists(firings_out)

# %%
print('recording: {}'.format(recording_path))
print('firings_out: {}'.format(firings_out))