class CreateWaveformsPlot(mlpr.Processor):
    """Plot unit waveforms of a sorting over its recording and save a .jpg.

    The recording (optionally restricted to ``channels``) is bandpass filtered
    to 300-6000 Hz before plotting. At most 20 channels and at most 20 units
    are drawn so that the figure stays legible.
    """

    NAME = 'CreateWaveformsPlot'
    VERSION = '0.1.0'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    units = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        max_plotted = 20  # cap on the number of channels and of units drawn

        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        S = si.MdaSortingExtractor(firings_file=self.firings)

        channels = R.getChannelIds()
        if len(channels) > max_plotted:
            channels = channels[0:max_plotted]

        if len(self.units) > 0:
            units = self.units
        else:
            units = S.getUnitIds()
        if len(units) > max_plotted:
            # Evenly subsample using a ceiling-division step so that at most
            # `max_plotted` units remain. A floor step (int(n / 20)) would keep
            # up to 39 units, e.g. step 1 when there are 39 units.
            step = -(-len(units) // max_plotted)
            units = units[::step]

        sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
        save_plot(self.jpg_out)
class ComputeUnitsInfo(mlpr.Processor):
    """Summarise each unit of a sorting and write the result as JSON.

    The output is a list with one dict per unit, holding ``unit_id``, ``snr``,
    ``peak_channel``, ``num_events`` and ``firing_rate``. Templates are
    estimated from the first 1e6 frames of the bandpass-filtered recording,
    which are loaded into a MemoryRecordingExtractor. Noise levels, spike
    counts and firing rates use the full recording and sorting.
    """

    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.1'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if self.channel_ids and len(self.channel_ids) > 0:
            raw = si.SubRecordingExtractor(parent_recording=raw, channel_ids=self.channel_ids)
        recording = sw.lazyfilters.bandpass_filter(recording=raw, freq_min=300, freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)

        # Only this leading segment is used for template estimation.
        end_frame = int(1e6)
        head_recording = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=end_frame)
        head_recording = MemoryRecordingExtractor(parent_recording=head_recording)
        head_sorting = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=end_frame)

        selected_units = self.unit_ids
        if not selected_units or len(selected_units) == 0:
            selected_units = sorting.getUnitIds()

        noise_levels = compute_channel_noise_levels(recording=recording)
        print('computing templates...')
        templates = compute_unit_templates(recording=head_recording, sorting=head_sorting,
                                           unit_ids=selected_units)
        print('.')

        infos = []
        for index, unit_id in enumerate(selected_units):
            template = templates[index]
            entry = {'unit_id': int(unit_id)}
            entry['snr'] = compute_template_snr(template, noise_levels)
            # Row with the largest absolute amplitude is taken as the peak channel.
            peak_row = np.argmax(np.max(np.abs(template), axis=1))
            entry['peak_channel'] = int(recording.getChannelIds()[peak_row])
            spike_train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            entry['num_events'] = int(len(spike_train))
            seconds = recording.getNumFrames() / recording.getSamplingFrequency()
            entry['firing_rate'] = float(len(spike_train) / seconds)
            infos.append(entry)
        write_json_file(self.json_out, infos)
class ComputeUnitsInfo(mlpr.Processor):
    """Summarise each unit of a sorting and write the result as JSON.

    For every unit, a template is computed on the (unfiltered) recording from
    up to 100 events. The channel with the largest peak-to-peak amplitude is
    taken as the peak channel. The unit's template is then recomputed on that
    single channel after a 300-6000 Hz bandpass filter. The SNR is the maximum
    absolute value of that template divided by the channel's noise level.
    """

    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.5k'
    CONTAINER = _CONTAINER

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        import spikewidgets as sw

        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if self.channel_ids and len(self.channel_ids) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
        # The recording is intentionally left unfiltered here; only the
        # single-channel recording used for the SNR template is filtered.
        recording = R0
        sorting = si.MdaSortingExtractor(firings_file=self.firings)

        unit_ids = self.unit_ids
        if not unit_ids or len(unit_ids) == 0:
            unit_ids = sorting.getUnitIds()

        channel_noise_levels = compute_channel_noise_levels(recording=recording)
        # Templates use the whole recording (no time subset), at most 100 events per unit.
        templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)
        channel_ids = recording.getChannelIds()

        ret = []
        for i, unit_id in enumerate(unit_ids):
            template = templates[i]
            max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
            peak_channel_index = int(np.argmax(max_p2p_amps_on_channels))
            # Convert the row index into the channel *id*. The two differ when
            # a channel subset was selected, so the id (not the index) must be
            # used to select the channel and to report it.
            peak_channel = channel_ids[peak_channel_index]

            R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
            R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
            templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
            template2 = templates2[0]

            info0 = dict()
            info0['unit_id'] = int(unit_id)
            # Noise levels are indexed by row, matching the templates' channel axis.
            info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
            info0['peak_channel'] = int(peak_channel)
            train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            info0['num_events'] = int(len(train))
            info0['firing_rate'] = float(
                len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
            ret.append(info0)
        write_json_file(self.json_out, ret)
class KiloSort(mlpr.Processor):
    """Spike-sort an MDA recording with KiloSort (via ``kilosort_helper``).

    A randomly named scratch directory under $TEMPDIR (default /tmp) is used.
    It is removed on success and on failure, unless the instance has a truthy
    ``_keep_temp_files`` attribute.
    """

    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=3, description='TODO')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan
            )
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        finally:
            # One cleanup path for success and failure (replaces the bare
            # `except:` + duplicated rmtree); exceptions still propagate.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
class CreateTimeseriesPlot(mlpr.Processor):
    """Plot a short bandpass-filtered excerpt from the middle of a recording as a .jpg.

    The excerpt covers up to 4000 frames ending at the recording midpoint,
    over at most the first 20 channels.
    """

    NAME = 'CreateTimeseriesPlot'
    VERSION = '0.1.7'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        N = R.getNumFrames()
        N2 = int(N / 2)
        # Clamp the start so recordings shorter than 8000 frames do not
        # produce a negative start frame.
        trange = [max(N2 - 4000, 0), N2]
        channels = R.getChannelIds()
        if len(channels) > 20:
            channels = channels[0:20]
        sw.TimeseriesWidget(recording=R, trange=trange, channels=channels,
                            width=12, height=5).plot()
        save_plot(self.jpg_out)
class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting against ground truth and write the table as JSON and HTML.

    Spikes are matched with ``delta_tp=30``. ``units_true``, when non-empty,
    restricts the ground-truth units considered.
    """

    VERSION = '0.2.6'

    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')

    # Earlier containers:
    #   sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg
    #   sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg  (MAY26JJJ)
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        banner = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(banner.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        if self.units_true is not None and len(self.units_true) > 0:
            truth = si.SubSortingExtractor(parent_sorting=truth, unit_ids=self.units_true)

        comparison = SortingComparison(truth, tested, delta_tp=30)
        table = get_comparison_data_frame(comparison=comparison)

        table_json = table.transpose().to_dict()
        table_html = table.to_html(index=False)
        _write_json_file(table_json, self.json_out)
        _write_json_file(table_html, self.html_out)
class YASS(mlpr.Processor):
    """Run the YASS spike sorter (through ``yass_helper``) and write the firings file.

    Work happens in a random scratch directory under $TEMPDIR (default /tmp).
    It is always removed on failure; on success it is removed unless the
    instance has a truthy ``_keep_temp_files`` attribute.
    """

    NAME = 'YASS'
    VERSION = '0.1.0'

    # Environment variables forwarded into the container.
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # share in which the container is looked up

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        suffix = ''.join([random.choice(string.ascii_uppercase) for _ in range(10)])
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + suffix
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            helper_kwargs = dict(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
            )
            sorting, _yaml_file = yass_helper(**helper_kwargs)
            se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        keep = getattr(self, '_keep_temp_files', False)
        if not keep:
            shutil.rmtree(tmpdir)
class PlotUnitWaveforms(mlpr.Processor):
    """Draw the waveforms of every unit in a sorting and save the figure as .jpg."""

    VERSION = '0.1.0'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec, channel_ids=self.channels)
        srt = si.MdaSortingExtractor(firings_file=self.firings)
        widget = sw.UnitWaveformsWidget(recording=rec, sorting=srt)
        widget.plot()
        save_plot(self.plot_out)
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit info with ``compute_units_info`` and write it to a .json file."""

    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.8'
    CONTAINER = _CONTAINER

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        write_json_file(
            self.json_out,
            compute_units_info(
                recording=SFMdaRecordingExtractor(dataset_directory=self.recording_dir, download=True),
                sorting=SFMdaSortingExtractor(firings_file=self.firings),
                channel_ids=self.channel_ids,
                unit_ids=self.unit_ids,
            ),
        )
class SpykingCircus(mlpr.Processor):
    """Run the Spyking Circus sorter (``sf.sorters.spyking_circus``) and write the firings.

    $NUM_WORKERS (default 2) sets the number of cores. Scratch output goes to
    a random directory under $TEMPDIR (default /tmp). That directory is
    removed on success and on failure, unless the instance has a truthy
    ``_keep_temp_files`` attribute.
    """

    NAME = 'SpykingCircus'
    VERSION = '0.1.2'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True, default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True, default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        # Previously prefixed 'ironclust-tmp-' (copy/paste from the IronClust wrapper).
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spykingcircus-tmp-' + code
        # Environment values are strings; convert so the sorter receives an int.
        # An unset or empty NUM_WORKERS falls back to 2.
        num_workers = int(os.environ.get('NUM_WORKERS') or 2)
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup path for success and failure; exceptions still propagate.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    """Run IronClust (``sf.sorters.ironclust``) on an MDA recording and write the firings.

    Requires $IRONCLUST_SRC to point at the IronClust sources. A random
    scratch directory under $TEMPDIR (default /tmp) is created and then
    removed whether or not sorting succeeds.
    """

    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')

        suffix = ''.join([random.choice(string.ascii_uppercase) for _ in range(10)])
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + suffix
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorter_kwargs = dict(
                recording=recording,
                tmpdir=tmpdir,  # TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src,
            )
            sorting = sf.sorters.ironclust(**sorter_kwargs)
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class RepeatText(mlpr.Processor):
    """Write the text of ``textfile`` repeated ``num_repeats`` times to ``textfile_out``."""

    textfile = mlpr.Input(help="input text file")
    textfile_out = mlpr.Output(help="output text file")
    # A single count. It was declared as an IntegerListParameter, but run()
    # compares it with 0 and repeats by it, which only works for an int.
    num_repeats = mlpr.IntegerParameter(
        help="Number of times to repeat the text")

    def run(self):
        """Read ``textfile`` and write its contents ``num_repeats`` times.

        Raises:
            ValueError: if ``num_repeats`` is negative.
        """
        num_repeats = self.num_repeats
        if num_repeats < 0:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError('num_repeats must be >= 0, got {}'.format(num_repeats))
        with open(self.textfile, 'r') as f:
            txt = f.read()
        with open(self.textfile_out, 'w') as f:
            # str * int builds the output in one step (the += loop was quadratic).
            f.write(txt * num_repeats)
class PlotAutoCorrelograms(mlpr.Processor):
    """Render correlograms of a sorting with ``sw.CrossCorrelogramsWidget`` and save a .jpg."""

    NAME = 'spikeforest.PlotAutoCorrelograms'
    VERSION = '0.1.0'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec, channel_ids=self.channels)
        srt = si.MdaSortingExtractor(firings_file=self.firings)
        # Only the sampling rate is taken from the recording.
        widget = sw.CrossCorrelogramsWidget(samplerate=rec.getSamplingFrequency(), sorting=srt)
        widget.plot()
        save_plot(self.plot_out)
class ComputeRecordingInfo(mlpr.Processor):
    """Write a recording's sample rate, channel count and duration (seconds) to JSON."""

    NAME = 'ComputeRecordingInfo'
    VERSION = '0.1.0'

    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec, channel_ids=self.channels)
        samplerate = rec.getSamplingFrequency()
        info = {
            'samplerate': samplerate,
            'num_channels': len(rec.getChannelIds()),
            'duration_sec': rec.getNumFrames() / samplerate,
        }
        write_json_file(self.json_out, info)
class GenSortingComparisonTableNew(mlpr.Processor):
    """Compare a sorting with ground truth via spiketoolkit and write JSON/HTML tables.

    Uses ``compare_sorter_to_ground_truth`` with ``delta_time=0.3``. The count
    and performance tables are joined, then columns are renamed to the legacy
    names (``unit_id``, ``num_matches``, ``best_unit``, ...). Finally
    ``matched_unit``, ``f_p`` (1 - precision) and ``f_n`` (1 - recall) are added.
    """

    VERSION = '0.3.1'

    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')

    # Earlier containers:
    #   sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg
    #   sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg  (MAY26JJJ)
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        banner = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(banner.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        if self.units_true is not None and len(self.units_true) > 0:
            truth = si.SubSortingExtractor(parent_sorting=truth, unit_ids=self.units_true)

        comparison = st.comparison.compare_sorter_to_ground_truth(
            gt_sorting=truth,
            tested_sorting=tested,
            delta_time=0.3,
            min_accuracy=0,
            compute_misclassification=False,
            exhaustive_gt=False,  # Fix this in future
        )

        legacy_names = {
            'gt_unit_id': 'unit_id',
            'fp': 'num_false_positives',
            'fn': 'num_false_negatives',
            'tested_id': 'best_unit',
            'tp': 'num_matches',
        }
        merged = pd.concat([comparison.count, comparison.get_performance()], axis=1)
        table = merged.reset_index().rename(columns=legacy_names)
        table['matched_unit'] = table['best_unit']
        table['f_p'] = 1 - table['precision']
        table['f_n'] = 1 - table['recall']

        table_json = table.transpose().to_dict()
        table_html = table.to_html(index=False)
        _write_json_file(table_json, self.json_out)
        _write_json_file(table_html, self.html_out)
class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting with ground truth and write the comparison table as JSON and HTML.

    The table comes from ``sw.SortingComparisonTable`` over a
    ``st.comparison.SortingComparison``. A non-empty ``units_true`` restricts
    the ground-truth units.
    """

    VERSION = '0.1.1'

    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')

    def run(self):
        tested = si.MdaSortingExtractor(firings_file=self.firings)
        truth = si.MdaSortingExtractor(firings_file=self.firings_true)
        if len(self.units_true) > 0:
            truth = si.SubSortingExtractor(parent_sorting=truth, unit_ids=self.units_true)
        comparison = st.comparison.SortingComparison(truth, tested)
        table = sw.SortingComparisonTable(comparison=comparison).getDataframe()
        table_json = table.transpose().to_dict()
        table_html = table.to_html(index=False)
        _write_json_file(table_json, self.json_out)
        _write_json_file(table_html, self.html_out)
class MountainSort4(mlpr.Processor):
    """Run MountainSort4 (``sf.sorters.mountainsort4``) and write the firings file.

    When $NUM_WORKERS is set to a non-empty value, it is converted to int and
    passed as ``num_workers``. Otherwise the raw lookup result (None or '') is
    passed unchanged.
    """

    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(optional=True, default=True,
                                description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(optional=True, default=0.15,
                                                  description='Use None for no automated curation')

    def run(self):
        rec = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec, channel_ids=self.channels)

        env_workers = os.environ.get('NUM_WORKERS', None)
        num_workers = int(env_workers) if env_workers else env_workers

        sorting = sf.sorters.mountainsort4(
            recording=rec,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            whiten=self.whiten,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            noise_overlap_threshold=self.noise_overlap_threshold,
            num_workers=num_workers,
        )
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
class JRClust(mlpr.Processor):
    """
    JRClust v4 wrapper by J. James Jun (May 8, 2019), adapted from
    `spikeforest/spikesorters/ironclust/ironclust.py`.

    Installation in a SpikeForest environment:
        1. `git clone https://github.com/JaneliaSciComp/JRCLUST`
        2. Activate the SpikeForest conda environment.
        3. Define `JRCLUST_PATH` and `MDAIO_PATH`.
        4. On Flatiron: `module load matlab/R2019a cuda/10.0.130_410.48`.
           Do NOT `module load gcc`. Do NOT use `matlab/R2018b` (it crashes
           while creating a parpool).

    Reference:
        James Jun, et al. Real-time spike sorting platform for high-density
        extracellular probes with ground-truth validation and drift correction.
        https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=4.5, description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=3000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=1, description='Number of principal components per channel')

    # Parameters below were added in version 0.2.4.
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(optional=True, default='none', description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(optional=True, default=1, description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(optional=True, default=0, description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        tmpdir = _get_tmpdir('jrclust')
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            dataset_params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            # Forward every declared processor parameter to the helper by name.
            sorter_params = {p.name: getattr(self, p.name) for p in self.PARAMETERS}
            sorting = jrclust_helper(recording=recording, tmpdir=tmpdir, params=dataset_params, **sorter_params)
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class TridesclousOld(mlpr.Processor):
    """
    Wrapper for the tridesclous spike sorter.

    tridesclous is one of the more convenient, fast and elegant spike sorters.

    Installation::

        pip install https://github.com/tridesclous/tridesclous/archive/master.zip

    More information on tridesclous at:
        * https://github.com/tridesclous/tridesclous
        * https://tridesclous.readthedocs.io
    """

    NAME = 'Tridesclous'
    VERSION = '0.1.1'  # wrapper VERSION
    ADDITIONAL_FILES = []
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    CONTAINER = 'sha1://9fb4a9350492ee84c8ea5d8692434ecba3cf33da/2019-05-13/tridesclous.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    channels = mlpr.IntegerListParameter(optional=True, default=[], description='List of channels to use.')
    detect_sign = mlpr.FloatParameter(optional=True, default=-1, description='')
    detection_threshold = mlpr.FloatParameter(optional=True, default=5.5, description='')
    freq_min = mlpr.FloatParameter(optional=True, default=400, description='')
    freq_max = mlpr.FloatParameter(optional=True, default=5000, description='')
    waveforms_n_left = mlpr.IntegerParameter(description='', optional=True, default=-45)
    waveforms_n_right = mlpr.IntegerParameter(description='', optional=True, default=60)
    align_waveform = mlpr.BoolParameter(description='', optional=True, default=False)
    common_ref_removal = mlpr.BoolParameter(description='', optional=True, default=False)
    peak_span = mlpr.FloatParameter(optional=True, default=.0002, description='')
    alien_value_threshold = mlpr.FloatParameter(optional=True, default=100, description='')

    def run(self):
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))
        # All setup now runs inside the try, so a failure while writing the
        # probe/raw files also cleans up tmpdir.
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Honor `channels` like the other sorter wrappers (previously ignored).
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)

            # Map detect_sign onto tridesclous' peak_sign (previously hard-coded
            # to '-'). Non-positive values keep the former '-' behavior.
            peak_sign = '+' if self.detect_sign > 0 else '-'

            params = {
                'fullchain_kargs': {
                    'duration': 300.,
                    'preprocessor': {
                        'highpass_freq': self.freq_min,
                        'lowpass_freq': self.freq_max,
                        'smooth_size': 0,
                        'chunksize': 1024,
                        'lostfront_chunksize': 128,
                        'signalpreprocessor_engine': 'numpy',
                        'common_ref_removal': self.common_ref_removal,
                    },
                    'peak_detector': {
                        'peakdetector_engine': 'numpy',
                        'peak_sign': peak_sign,
                        'relative_threshold': self.detection_threshold,
                        'peak_span': self.peak_span,
                    },
                    'noise_snippet': {
                        'nb_snippet': 300,
                    },
                    'extract_waveforms': {
                        'n_left': self.waveforms_n_left,
                        'n_right': self.waveforms_n_right,
                        'mode': 'rand',
                        'nb_max': 20000,
                        'align_waveform': self.align_waveform,
                    },
                    'clean_waveforms': {
                        'alien_value_threshold': self.alien_value_threshold,
                    },
                },
                'feat_method': 'peak_max',
                'feat_kargs': {},
                'clust_method': 'sawchaincut',
                'clust_kargs': {
                    'kde_bandwith': 1.
                },
            }

            # Save the probe file.
            probe_file = tmpdir / 'probe.prb'
            se.save_probe_file(recording, probe_file, format='spyking_circus')

            # Source file for tridesclous.
            if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
                # Existing frame-first binary file: no need to copy.
                raw_filename = recording._datfile
                dtype = recording._timeseries.dtype.str
                offset = recording._timeseries.offset
            else:
                # Write the traces, chunk by chunk, into a new float32 file.
                raw_filename = tmpdir / 'raw_signals.raw'
                n_chan = recording.get_num_channels()
                chunksize = 2**24 // n_chan
                se.write_binary_dat_format(recording, raw_filename, time_axis=0, dtype='float32',
                                           chunksize=chunksize)
                dtype = 'float32'
                offset = 0

            # Initialize the data source and probe file.
            tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
            tdc_dataio.set_data_source(
                type='RawData', filenames=[str(raw_filename)], dtype=dtype,
                sample_rate=recording.get_sampling_frequency(),
                total_channel=recording.get_num_channels(), offset=offset)
            tdc_dataio.set_probe_file(str(probe_file))

            sorting = tdc_helper(tmpdir=tmpdir, params=params, recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class KiloSort2Old(mlpr.Processor):
    """
    KiloSort2 wrapper for SpikeForest framework
      written by J. James Jun, May 7, 2019
      modified from `spiketoolkit/sorters/Kilosort`
      to be made compatible with SpikeForest

    [Prerequisite]
    1. MATLAB (Tested on R2018b)
    2. CUDA Toolkit v9.1

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/alexmorley/Kilosort2.git`
       Kilosort2 currently doesn't work on tetrodes and low-channel count
       probes (as of May 7, 2019). Clone from Alex Morley's repository that
       fixed these issues. Original Kilosort2 code can be obtained from
       `https://github.com/MouseLand/Kilosort2.git`
    2. (optional) If Alex Morley's latest version doesn't work with
       SpikeForest, run `git checkout 43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956`
    3. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    4. Add `KILOSORT2_PATH=...` in your .bashrc file.

    [Scratch directory]
    `run` works in `$TEMPDIR/kilosort2-tmp-<random>` (TEMPDIR defaults to
    /tmp). The directory is removed afterwards, on success or failure,
    unless the processor has a truthy `_keep_temp_files` attribute.
    """

    NAME = 'KiloSort2'
    VERSION = '0.3.3'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS'
    ]
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        default=-1, optional=True,
        description='Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30, optional=True,
        description='Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=6, description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=150, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50, optional=True,
        description='minimum spike rate (Hz), if a cluster falls below this for too long it gets removed'
    )

    @staticmethod
    def install():
        """Install KiloSort2 from Alex Morley's fork at a pinned commit and return what the installer returns."""
        print('Auto-installing kilosort.')
        return install_kilosort2(
            repo='https://github.com/alexmorley/Kilosort2',
            commit='43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956')

    def run(self):
        """Sort `recording_dir` with KiloSort2 and write the result to `firings_out`.

        Re-raises any exception from loading, sorting or writing, after
        removing the scratch directory unless `_keep_temp_files` is set.
        """
        # Honour the framework-level debug flag like the sibling wrappers do.
        # This used to be hard-coded to True, which made the cleanup below
        # dead code and leaked one scratch directory per run.
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort2_helper(recording=recording,
                                       tmpdir=tmpdir,
                                       detect_sign=self.detect_sign,
                                       adjacency_radius=self.adjacency_radius,
                                       detect_threshold=self.detect_threshold,
                                       merge_thresh=self.merge_thresh,
                                       freq_min=self.freq_min,
                                       freq_max=self.freq_max,
                                       pc_per_chan=self.pc_per_chan,
                                       minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then propagate.
            if os.path.exists(tmpdir):
                if not keep_temp_files:
                    print('removing tmpdir1')
                    shutil.rmtree(tmpdir)
            raise
        if not keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    """Run IronClust through the `spikesorters` IronClustSorter wrapper.

    Loads the recording in `recording_dir` and restricts it to `channels`
    when that list is non-empty. It then sorts with the parameters declared
    below, using an output folder under `$TEMPDIR` (default /tmp). The
    resulting sorting is written to `firings_out`. The output folder is not
    deleted here: `delete_output_folder=False` defers that to the caller.
    """

    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    channels = mlpr.IntegerListParameter(
        optional=True, default=[],
        description='List of channels to use.'
    )
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood'
    )
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood'
    )
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4,
        description='detection threshold'
    )
    prm_template_name = mlpr.StringParameter(
        optional=True, default='',
        description='.prm template file name'
    )
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering'
    )
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000,
        description='Use 0 for no bandpass filtering'
    )
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging'
    )
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0,
        description='Number of principal components per channel'
    )

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False,
        description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass',
        description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none',
        description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean',
        description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300,
        description='batch duration in seconds. clustering time duration')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20,
        description='compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30,
        description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30,
        description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True,
        description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8,
        description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0,
        description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32,
        description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1,
        description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1,
        description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1,
        description='sort_mode')

    @staticmethod
    def install():
        """Install IronClust at a pinned commit and return the path the installer reports."""
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        """Sort `recording_dir` with IronClust and write the result to `firings_out`.

        Raises:
            Exception: 'Problem installing ironclust.' when `install()`
                fails. The original error is chained as `__cause__`.
        """
        import spikesorters as sorters

        print('IronClust......')
        try:
            ironclust_path = IronClust.install()
        except Exception as err:
            # Only wrap ordinary errors; KeyboardInterrupt/SystemExit must
            # propagate unchanged instead of being turned into an Exception.
            traceback.print_exc()
            raise Exception('Problem installing ironclust.') from err
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # `channels` is a declared parameter; apply it like the sibling
        # wrappers do instead of silently ignoring it.
        if len(self.channels) > 0:
            recording = se.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channels)

        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=False  # will be taken care by _keep_temp_files one step above
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        # The value returned by sorter.run() is printed as the sorter runtime.
        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
class SpykingCircus(mlpr.Processor):
    """Run Spyking Circus via the `spyking_circus` helper on an MDA recording.

    The recording is optionally restricted to `channels`. The helper's
    output folder is a scratch directory `$TEMPDIR/spyking-circus-tmp-<random>`
    (TEMPDIR defaults to /tmp). It is removed afterwards, on success or
    failure, unless the processor has a truthy `_keep_temp_files` attribute.
    """

    NAME = 'SpykingCircus'
    VERSION = '0.2.2'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.params']
    CONTAINER = 'sha1://914becce45aec56a84dd1dd4bca4037b09c50373/02-12-2019/spyking_circus.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True, default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True, default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort `recording_dir` with Spyking Circus and write the result to `firings_out`.

        Re-raises any exception from loading, sorting or writing, after
        removing the scratch directory unless `_keep_temp_files` is set.

        Raises:
            ValueError: if the NUM_WORKERS environment variable is set to a
                value that is not an integer.
        """
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
        # os.environ values are strings while the fallback is the int 1;
        # normalise so n_cores always has the same type.
        num_workers = int(os.environ.get('NUM_WORKERS', 1))

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts,
            )
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then propagate.
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        if not keep_temp_files and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
class KiloSortOld(mlpr.Processor):
    """
    Kilosort wrapper for SpikeForest framework
      written by J. James Jun, May 21, 2019

    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0

    [Optional: Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/cortex-lab/KiloSort.git`
    2. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    3. Add `KILOSORT_PATH_DEV=...` in your .bashrc file.

    [Scratch directory]
    `run` works in `$TEMPDIR/kilosort-tmp-<random>` (TEMPDIR defaults to
    /tmp). The directory is removed afterwards, on success or failure,
    unless the processor has a truthy `_keep_temp_files` attribute.
    """

    NAME = 'KiloSort'
    VERSION = '0.2.4'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS'
    ]
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=-1, description='Currently unused')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    @staticmethod
    def install():
        """Install KiloSort from cortex-lab at a pinned commit and return what the installer returns."""
        print('Auto-installing kilosort.')
        return install_kilosort(
            repo='https://github.com/cortex-lab/KiloSort.git',
            commit='3f33771f8fdf8c3846a7f8a75cc8c318b44ed48c')

    def run(self):
        """Sort `recording_dir` with KiloSort and write the result to `firings_out`.

        Re-raises any exception from loading, sorting or writing, after
        removing the scratch directory unless `_keep_temp_files` is set.
        """
        # Read the debug flag once so both cleanup paths agree. The failure
        # path previously used a hard-coded False and deleted the tmpdir even
        # when the caller asked to keep it for debugging.
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(recording=recording,
                                      tmpdir=tmpdir,
                                      detect_sign=self.detect_sign,
                                      adjacency_radius=self.adjacency_radius,
                                      detect_threshold=self.detect_threshold,
                                      merge_thresh=self.merge_thresh,
                                      freq_min=self.freq_min,
                                      freq_max=self.freq_max,
                                      pc_per_chan=self.pc_per_chan)
            SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then propagate.
            if os.path.exists(tmpdir):
                if not keep_temp_files:
                    shutil.rmtree(tmpdir)
            raise
        if not keep_temp_files:
            shutil.rmtree(tmpdir)