class KiloSort(mlpr.Processor):
    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    # prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan
            )
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class MountainSort4(mlpr.Processor):
    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    dataset_dir = mlpr.Input('Directory of dataset', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        recording = si.MdaRecordingExtractor(self.dataset_dir)
        num_workers = int(os.environ.get('NUM_WORKERS', -1))
        if num_workers <= 0:
            num_workers = None
        sorting = sf.sorters.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            whiten=self.whiten,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            noise_overlap_threshold=self.noise_overlap_threshold,
            num_workers=num_workers)
        si.MdaSortingExtractor.writeSorting(
            sorting=sorting, save_path=self.firings_out)
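# Usage sketch (illustrative; not part of the source). mlpr processors are
# typically driven through the `execute` classmethod, with parameters passed
# as keyword arguments and outputs requested by extension. The output-spec
# form ({'ext': ...}) and the `result.outputs` attribute are assumptions
# about the mlpr API; the recording path below is hypothetical.
def _example_run_mountainsort4():
    result = MountainSort4.execute(
        dataset_dir='/path/to/recording',  # hypothetical dataset directory
        firings_out={'ext': '.mda'},       # request an .mda output file
        detect_sign=-1,
        adjacency_radius=50)
    print(result.outputs['firings_out'])   # path of the produced firings file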
class GenerateMearecRecording(mlpr.Processor):
    NAME = "GenerateMearecRecording"
    VERSION = "0.1.0"

    # input file
    templates_in = mlpr.Input(description='.h5 file containing templates')

    # output file
    recording_out = mlpr.Output()

    # recordings params
    drifting = mlpr.BoolParameter()
    noise_level = mlpr.FloatParameter()
    bursting = mlpr.BoolParameter()
    shape_mod = mlpr.BoolParameter()

    # spiketrains params
    duration = mlpr.FloatParameter()
    n_exc = mlpr.IntegerParameter()
    n_inh = mlpr.IntegerParameter()

    # templates params
    min_dist = mlpr.FloatParameter()

    # seed
    seed = mlpr.IntegerParameter()

    def run(self):
        recordings_params = deepcopy(mr.get_default_recordings_params())
        recordings_params['recordings']['drifting'] = self.drifting
        recordings_params['recordings']['noise_level'] = self.noise_level
        recordings_params['recordings']['bursting'] = self.bursting
        recordings_params['recordings']['shape_mod'] = self.shape_mod
        recordings_params['recordings']['seed'] = self.seed
        # recordings_params['recordings']['chunk_conv_duration'] = 0  # turn off parallel execution
        recordings_params['spiketrains']['duration'] = self.duration
        recordings_params['spiketrains']['n_exc'] = self.n_exc
        recordings_params['spiketrains']['n_inh'] = self.n_inh
        recordings_params['spiketrains']['seed'] = self.seed
        recordings_params['templates']['min_dist'] = self.min_dist
        recordings_params['templates']['seed'] = self.seed

        # this is needed because mr.load_templates requires the file extension
        templates_fname = self.templates_in + '.h5'
        shutil.copyfile(self.templates_in, templates_fname)

        tempgen = mr.load_templates(Path(templates_fname))
        recgen = mr.gen_recordings(
            params=recordings_params, tempgen=tempgen, verbose=False)
        mr.save_recording_generator(recgen, self.recording_out)
        del recgen
class SpykingCircus(mlpr.Processor):
    NAME = 'SpykingCircus'
    VERSION = '0.1.2'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True, default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True, default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
        num_workers = int(os.environ.get('NUM_WORKERS', 2))
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class ComputeMandelbrotWithError(mlpr.Processor):
    NAME = 'ComputeMandelbrotWithError'
    VERSION = '0.1.4'

    xmin = mlpr.FloatParameter('The minimum x value', optional=True, default=-2)
    xmax = mlpr.FloatParameter('The maximum x value', optional=True, default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value', optional=True, default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value', optional=True, default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension', optional=True, default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations', optional=True, default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)', optional=True, default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset', optional=True, default=0)
    throw_error = mlpr.BoolParameter(
        'Whether to intentionally throw an error for testing purposes',
        optional=True, default=False)
    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        import time
        if self.throw_error:
            print('Intentionally throwing error in 2 seconds...')
            time.sleep(2)
            raise Exception('Intentionally throwing error.')
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(
            xmin=self.xmin, xmax=self.xmax,
            ymin=self.ymin, ymax=self.ymax,
            num_x=self.num_x, num_iter=self.num_iter,
            subsampling_factor=self.subsampling_factor,
            subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
class ComputeUnitDetail(mlpr.Processor):
    NAME = 'ComputeUnitDetail'
    VERSION = '0.1.0'
    CONTAINER = None

    recording_dir = mlpr.Input(description='Recording directory', optional=False, directory=True)
    firings = mlpr.Input(description='Input firings.mda file')
    unit_id = mlpr.IntegerParameter(description='Unit ID')
    json_out = mlpr.Output(description='Output .json file')

    def run(self):
        recording = SFMdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        waveforms0 = _get_random_spike_waveforms(
            recording=recording, sorting=sorting, unit=self.unit_id)
        channel_ids = recording.get_channel_ids()
        avg_waveform = np.median(waveforms0, axis=2)
        ret = dict(
            channel_ids=channel_ids,
            average_waveform=avg_waveform.tolist())
        with open(self.json_out, 'w') as f:
            json.dump(ret, f)
class CreateEfficientAccessRecordingFile(mlpr.Processor):
    NAME = 'CreateEfficientAccessRecordingFile'
    VERSION = '0.1.2'

    recording = mlpr.Input()
    segment_size = mlpr.IntegerParameter(optional=True, default=1000000)
    hdf5_out = mlpr.Output()

    def run(self):
        import h5py
        recording = self.recording
        segment_size = self.segment_size
        channel_ids = recording.get_channel_ids()
        samplerate = recording.get_sampling_frequency()
        M = len(channel_ids)  # number of channels
        N = recording.get_num_frames()  # number of timepoints
        num_segments = int(np.ceil(N / segment_size))
        try:
            channel_locations = recording.get_channel_locations(channel_ids=channel_ids)
            nd = len(channel_locations[0])
            geom = np.zeros((M, nd))
            for m in range(M):
                geom[m, :] = channel_locations[m]
        except:
            nd = 2
            geom = np.zeros((M, nd))
        with h5py.File(self.hdf5_out, "w") as f:
            f.create_dataset('segment_size', data=[segment_size])
            f.create_dataset('num_segments', data=[num_segments])
            f.create_dataset('num_channels', data=[M])
            f.create_dataset('channel_ids', data=np.array(channel_ids))
            f.create_dataset('num_timepoints', data=[N])
            f.create_dataset('samplerate', data=[samplerate])
            f.create_dataset('geom', data=geom)
            if callable(recording.hash):
                hash0 = recording.hash()
            else:
                hash0 = recording.hash
            f.create_dataset('recording_hash', data=np.array([hash0.encode()]))
            for j in range(num_segments):
                segment = np.zeros((M, segment_size), dtype=float)  # fix dtype here
                t1 = int(j * segment_size)  # first timepoint of the segment
                t2 = int(np.minimum(N, (t1 + segment_size)))  # last timepoint of segment (+1)
                s1 = int(np.maximum(0, t1))  # first timepoint
                s2 = int(np.minimum(N, t2))  # last timepoint (+1)
                # determine aa so that t1-s1+aa = 0; so, aa = -(t1-s1)
                aa = -(t1 - s1)
                # read the segment
                segment[:, aa:aa + s2 - s1] = recording.get_traces(
                    start_frame=s1, end_frame=s2)
                for ii, ch in enumerate(channel_ids):
                    f.create_dataset('part-{}-{}'.format(ch, j),
                                     data=segment[ii, :].ravel())
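# Read-back sketch (illustrative; not part of the source). Shows how the
# per-channel segment datasets written by CreateEfficientAccessRecordingFile
# can be reassembled; the dataset names match those created in run() above.
def _example_read_channel(hdf5_path, channel_id):
    import h5py
    with h5py.File(hdf5_path, 'r') as f:
        num_segments = int(f['num_segments'][0])
        num_timepoints = int(f['num_timepoints'][0])
        parts = [f['part-{}-{}'.format(channel_id, j)][...]
                 for j in range(num_segments)]
    # concatenate the fixed-size segments and trim the zero padding at the end
    return np.concatenate(parts)[:num_timepoints]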
class YASS(mlpr.Processor):
    NAME = 'YASS'
    VERSION = '0.1.0'
    # used by container to pass the env variables
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    # paramfile_out = mlpr.Output('YASS yaml config file')
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code
        # num_workers = os.environ.get('NUM_WORKERS', 1)
        # print('num_workers: {}'.format(num_workers))
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting, yaml_file = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
            # shutil.copyfile(yaml_file, self.paramfile_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class CreateSpikeSprays(mlpr.Processor):
    NAME = 'CreateSpikeSprays'
    VERSION = '0.1.0'
    CONTAINER = _CONTAINER

    recording_directory = mlpr.Input(
        description='Recording directory', optional=False, directory=True)
    filtered_timeseries = mlpr.Input(
        description='Filtered timeseries file (.mda)', optional=False)
    firings_true = mlpr.Input(
        description='True firings -- firings_true.mda', optional=False)
    firings_sorted = mlpr.Input(
        description='Sorted firings -- firings.mda', optional=False)
    unit_id_true = mlpr.IntegerParameter(description='ID of the true unit')
    unit_id_sorted = mlpr.IntegerParameter(description='ID of the sorted unit')
    neighborhood_size = mlpr.IntegerParameter(
        description='Max size of the electrode neighborhood', optional=True, default=7)
    num_spikes = mlpr.IntegerParameter(
        description='Max number of spikes in the spike spray', optional=True, default=20)
    json_out = mlpr.Output(description='Output json object')

    def run(self):
        rx = SFMdaRecordingExtractor(
            dataset_directory=self.recording_directory,
            download=True,
            raw_fname=self.filtered_timeseries)
        sx_true = SFMdaSortingExtractor(firings_file=self.firings_true)
        sx = SFMdaSortingExtractor(firings_file=self.firings_sorted)
        ssobj = create_spikesprays(
            rx=rx, sx_true=sx_true, sx_sorted=sx,
            neighborhood_size=self.neighborhood_size,
            num_spikes=self.num_spikes,
            unit_id_true=self.unit_id_true,
            unit_id_sorted=self.unit_id_sorted)
        with open(self.json_out, 'w') as f:
            json.dump(ssobj, f)
class ComputeMandelbrot(mlpr.Processor):
    NAME = 'ComputeMandelbrot'
    VERSION = '0.1.4'

    xmin = mlpr.FloatParameter('The minimum x value', optional=True, default=-2)
    xmax = mlpr.FloatParameter('The maximum x value', optional=True, default=0.5)
    ymin = mlpr.FloatParameter('The minimum y value', optional=True, default=-1.25)
    ymax = mlpr.FloatParameter('The maximum y value', optional=True, default=1.25)
    num_x = mlpr.IntegerParameter(
        'The number of points (resolution) in the x dimension', optional=True, default=1000)
    num_iter = mlpr.IntegerParameter('Number of iterations', optional=True, default=1000)
    subsampling_factor = mlpr.IntegerParameter(
        'Subsampling factor (1 means no subsampling)', optional=True, default=1)
    subsampling_offset = mlpr.IntegerParameter('Subsampling offset', optional=True, default=0)
    output_npy = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        print('=== ComputeMandelbrot ===', self.subsampling_factor, self.subsampling_offset)
        if self.subsampling_factor > 1:
            print('Using subsampling factor {}, offset {}'.format(
                self.subsampling_factor, self.subsampling_offset))
        X = compute_mandelbrot(
            xmin=self.xmin, xmax=self.xmax,
            ymin=self.ymin, ymax=self.ymax,
            num_x=self.num_x, num_iter=self.num_iter,
            subsampling_factor=self.subsampling_factor,
            subsampling_offset=self.subsampling_offset)
        np.save(self.output_npy, X)
class ComputeNPrimes(mlpr.Processor):
    NAME = 'ComputeNPrimes'
    VERSION = '0.1.3'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output .npy file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        primes = compute_n_primes(self.n)
        print('Prime {}: {}'.format(self.n, primes[-1]))
        np.save(self.output, primes)
class ComputeNthPrime(mlpr.Processor):
    NAME = 'ComputeNthPrime'
    VERSION = '0.1.1'

    n = mlpr.IntegerParameter('The integer n.')
    output = mlpr.Output('The output text file.')

    def __init__(self):
        mlpr.Processor.__init__(self)

    def run(self):
        prime = nth_prime_number(self.n)
        with open(self.output, 'w') as f:
            f.write('{}'.format(prime))
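# Hypothetical helper sketches: compute_n_primes / nth_prime_number are
# defined elsewhere in the repository; a simple trial-division version
# consistent with how the processors above use them might look like this.
def _compute_n_primes_sketch(n):
    primes = []
    candidate = 2
    while len(primes) < n:
        # candidate is prime iff no smaller prime divides it
        if all(candidate % p != 0 for p in primes):
            primes.append(candidate)
        candidate += 1
    return primes


def _nth_prime_number_sketch(n):
    return _compute_n_primes_sketch(n)[-1]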
class CombineSubsampledMandelbrot(mlpr.Processor):
    NAME = 'CombineSubsampledMandelbrot'
    VERSION = '0.1.1'

    X_list = mlpr.Input(multi=True)
    X_out = mlpr.Output()
    num_x = mlpr.IntegerParameter()

    def run(self):
        print('=== CombineSubsampledMandelbrot ===', self.num_x)
        arrays = []
        for X0 in self.X_list:
            arrays.append(np.load(X0))
        X = combine_subsampled_mandelbrot(arrays)
        X = X[:self.num_x, :]
        np.save(self.X_out, X)
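# Map/combine sketch (illustrative; not part of the source). The assumption,
# based on the parameters above, is that compute_mandelbrot with
# subsampling_factor=k and offsets 0..k-1 partitions the image into k
# interleaved strips which combine_subsampled_mandelbrot merges back together.
def _example_subsampled_pipeline(num_parallel=4):
    arrays = []
    for offset in range(num_parallel):
        arrays.append(compute_mandelbrot(
            num_x=1000, num_iter=1000,
            subsampling_factor=num_parallel,
            subsampling_offset=offset))
    X = combine_subsampled_mandelbrot(arrays)
    return X[:1000, :]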
class ComputeAutocorrelograms(mlpr.Processor):
    NAME = 'ComputeAutocorrelograms'
    VERSION = '0.1.4'

    # Inputs
    firings_path = mlpr.Input()

    # Parameters
    samplerate = mlpr.FloatParameter()
    max_samples = mlpr.IntegerParameter(optional=True, default=100000)
    bin_size_msec = mlpr.FloatParameter(optional=True, default=2)
    max_dt_msec = mlpr.FloatParameter(optional=True, default=50)

    # Outputs
    json_out = mlpr.Output()

    def run(self):
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
        print('test1', self.firings_path, self.samplerate)
        sorting = SFMdaSortingExtractor(firings_file=self.firings_path)
        samplerate = self.samplerate
        max_samples = self.max_samples
        max_dt_msec = self.max_dt_msec
        bin_size_msec = self.bin_size_msec
        max_dt_tp = max_dt_msec * samplerate / 1000
        bin_size_tp = bin_size_msec * samplerate / 1000
        autocorrelograms = []
        for unit_id in sorting.get_unit_ids():
            print('Unit {}'.format(unit_id))
            (bin_counts, bin_edges) = compute_autocorrelogram(
                sorting.get_unit_spike_train(unit_id),
                max_dt_tp=max_dt_tp,
                bin_size_tp=bin_size_tp,
                max_samples=max_samples)
            autocorrelograms.append(dict(
                unit_id=unit_id,
                bin_counts=bin_counts,
                bin_edges=bin_edges
            ))
        ret = dict(
            autocorrelograms=autocorrelograms
        )
        with open(self.json_out, 'w') as f:
            json.dump(serialize_np(ret), f)
class ComputeUnitDetail(mlpr.Processor):
    NAME = 'ComputeUnitDetail'
    VERSION = '0.1.5'

    recording = mlpr.Input()
    sorting = mlpr.Input()
    unit_id = mlpr.IntegerParameter()
    output = mlpr.Output()

    def run(self):
        event_times = self.sorting.get_unit_spike_train(unit_id=self.unit_id)  # pylint: disable=no-member
        snippets = self.recording.get_snippets(
            reference_frames=event_times, snippet_len=100)  # pylint: disable=no-member
        template = np.median(np.stack(snippets), axis=0)
        result0 = dict(
            unit_id=self.unit_id,
            num_events=len(event_times),
            event_times=event_times,
            snippets=snippets,
            template=template)
        with open(self.output, 'wb') as f:
            pickle.dump(result0, f)
class ComputeAutocorrelograms(mlpr.Processor):
    NAME = 'ComputeAutocorrelograms'
    VERSION = '0.1.5'

    # Inputs
    sorting = mlpr.Input()

    # Parameters
    max_samples = mlpr.IntegerParameter(optional=True, default=100000)
    bin_size_msec = mlpr.FloatParameter(optional=True, default=2)
    max_dt_msec = mlpr.FloatParameter(optional=True, default=50)

    # Outputs
    json_out = mlpr.Output()

    def run(self):
        sorting = self.sorting
        samplerate = sorting.get_sampling_frequency()
        max_samples = self.max_samples
        max_dt_msec = self.max_dt_msec
        bin_size_msec = self.bin_size_msec
        max_dt_tp = max_dt_msec * samplerate / 1000
        bin_size_tp = bin_size_msec * samplerate / 1000
        autocorrelograms = []
        for unit_id in sorting.get_unit_ids():
            (bin_counts, bin_edges) = compute_autocorrelogram(
                sorting.get_unit_spike_train(unit_id=unit_id),
                max_dt_tp=max_dt_tp,
                bin_size_tp=bin_size_tp,
                max_samples=max_samples)
            bin_edges = bin_edges / samplerate * 1000  # milliseconds
            autocorrelograms.append(
                dict(unit_id=unit_id, bin_counts=bin_counts, bin_edges=bin_edges))
        ret = dict(autocorrelograms=autocorrelograms)
        with open(self.json_out, 'w') as f:
            simplejson.dump(serialize_np(ret), f, ignore_nan=True)
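# Sketch of what compute_autocorrelogram plausibly computes (the real
# implementation lives elsewhere in the repository): a histogram of pairwise
# spike-time differences within +/- max_dt_tp, subsampled to max_samples.
def _compute_autocorrelogram_sketch(times, *, max_dt_tp, bin_size_tp, max_samples=None):
    times = np.asarray(times)
    if max_samples is not None and len(times) > max_samples:
        times = np.sort(np.random.choice(times, size=max_samples, replace=False))
    deltas = []
    for i in range(len(times)):
        j = i + 1
        while j < len(times) and times[j] - times[i] <= max_dt_tp:
            deltas.append(times[j] - times[i])
            j += 1
    deltas = np.array(deltas, dtype=float)
    deltas = np.concatenate([deltas, -deltas])  # symmetric about zero lag
    bin_edges = np.arange(-max_dt_tp, max_dt_tp + bin_size_tp, bin_size_tp)
    bin_counts, bin_edges = np.histogram(deltas, bins=bin_edges)
    return bin_counts, bin_edges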
class TridesclousOld(mlpr.Processor):
    """
    tridesclous is one of the more convenient, fast and elegant spike sorters.

    Installation instruction
        >>> pip install https://github.com/tridesclous/tridesclous/archive/master.zip

    More information on tridesclous at:
      * https://github.com/tridesclous/tridesclous
      * https://tridesclous.readthedocs.io
    """

    NAME = 'Tridesclous'
    VERSION = '0.1.1'  # wrapper VERSION
    ADDITIONAL_FILES = []
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER = 'sha1://9fb4a9350492ee84c8ea5d8692434ecba3cf33da/2019-05-13/tridesclous.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[], description='List of channels to use.')
    detect_sign = mlpr.FloatParameter(optional=True, default=-1, description='')
    detection_threshold = mlpr.FloatParameter(optional=True, default=5.5, description='')
    freq_min = mlpr.FloatParameter(optional=True, default=400, description='')
    freq_max = mlpr.FloatParameter(optional=True, default=5000, description='')
    waveforms_n_left = mlpr.IntegerParameter(description='', optional=True, default=-45)
    waveforms_n_right = mlpr.IntegerParameter(description='', optional=True, default=60)
    align_waveform = mlpr.BoolParameter(description='', optional=True, default=False)
    common_ref_removal = mlpr.BoolParameter(description='', optional=True, default=False)
    peak_span = mlpr.FloatParameter(optional=True, default=.0002, description='')
    alien_value_threshold = mlpr.FloatParameter(optional=True, default=100, description='')

    def run(self):
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))
        recording = SFMdaRecordingExtractor(self.recording_dir)

        params = {
            'fullchain_kargs': {
                'duration': 300.,
                'preprocessor': {
                    'highpass_freq': self.freq_min,
                    'lowpass_freq': self.freq_max,
                    'smooth_size': 0,
                    'chunksize': 1024,
                    'lostfront_chunksize': 128,
                    'signalpreprocessor_engine': 'numpy',
                    'common_ref_removal': self.common_ref_removal,
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': self.detection_threshold,
                    'peak_span': self.peak_span,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': self.waveforms_n_left,
                    'n_right': self.waveforms_n_right,
                    'mode': 'rand',
                    'nb_max': 20000,
                    'align_waveform': self.align_waveform,
                },
                'clean_waveforms': {
                    'alien_value_threshold': self.alien_value_threshold,
                },
            },
            'feat_method': 'peak_max',
            'feat_kargs': {},
            'clust_method': 'sawchaincut',
            'clust_kargs': {
                'kde_bandwith': 1.  # spelling follows the tridesclous parameter name
            },
        }

        # save prb file:
        probe_file = tmpdir / 'probe.prb'
        se.save_probe_file(recording, probe_file, format='spyking_circus')

        # source file
        if isinstance(recording, se.BinDatRecordingExtractor) and recording._frame_first:
            # no need to copy
            raw_filename = recording._datfile
            dtype = recording._timeseries.dtype.str
            nb_chan = len(recording._channels)
            offset = recording._timeseries.offset
        else:
            # save binary file (chunk by chunk) into a new file
            raw_filename = tmpdir / 'raw_signals.raw'
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            se.write_binary_dat_format(
                recording, raw_filename, time_axis=0, dtype='float32', chunksize=chunksize)
            dtype = 'float32'
            offset = 0

        # initialize source and probe file
        tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
        nb_chan = recording.get_num_channels()
        tdc_dataio.set_data_source(
            type='RawData', filenames=[str(raw_filename)], dtype=dtype,
            sample_rate=recording.get_sampling_frequency(),
            total_channel=nb_chan, offset=offset)
        tdc_dataio.set_probe_file(str(probe_file))

        try:
            sorting = tdc_helper(tmpdir=tmpdir, params=params, recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class JRClust(mlpr.Processor):
    """
    JRClust v4 wrapper
    written by J. James Jun, May 8, 2019
    modified from `spikeforest/spikesorters/ironclust/ironclust.py`

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/JaneliaSciComp/JRCLUST`
    2. Activate conda environment for SpikeForest
    3. Create `JRCLUST_PATH` and `MDAIO_PATH`
    4. Flatiron execution: `module load matlab/R2019a cuda/10.0.130_410.48`
       DO NOT LOAD `module load gcc`.
       DO NOT USE `matlab/R2018b` (it will crash while creating a parpool).

    See: James Jun, et al. Real-time spike sorting platform for high-density
    extracellular probes with ground-truth validation and drift correction
    https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4.5, description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=3000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=1, description='Number of principal components per channel')

    # added in version 0.2.4
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(
        optional=True, default=1, description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(
        optional=True, default=0, description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        tmpdir = _get_tmpdir('jrclust')
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            all_params = dict()
            for param0 in self.PARAMETERS:
                all_params[param0.name] = getattr(self, param0.name)
            sorting = jrclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **all_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class SpykingCircus(mlpr.Processor):
    NAME = 'SpykingCircus'
    VERSION = '0.3.4'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.params']
    # CONTAINER = 'sha1://8958530b960522d529163344af2faa09ea805716/2019-05-06/spyking_circus.simg'
    # CONTAINER = 'sha1://eed2314fbe2fb1cc7cfe0a36b4e205ffb94add1c/2019-06-17/spyking_circus.simg'
    # CONTAINER = 'sha1://68a175faef53e29af068b8b95649021593f9020a/2019-07-01/spyking_circus.simg'
    CONTAINER = 'sha1://5ca21c482edaf4b3b689f2af3c719a32567ba21e/2019-07-22/spyking_circus.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=200,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True, default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True, default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        import spikesorters as sorters
        print('SpyKING CIRCUS......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-' + code
        num_workers = int(os.environ.get('NUM_WORKERS', '1'))

        sorter = sorters.SpykingcircusSorter(
            recording=recording,
            output_folder=tmpdir,
            verbose=True,
            delete_output_folder=True)
        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            auto_merge=0.5,
            num_workers=num_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts,
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class MountainSort4Old(mlpr.Processor):
    NAME = 'MountainSort4'
    VERSION = '4.3.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = 'sha1://e06fee7f72f6b66d80d899ebc08e7c39e5a2458e/2019-05-06/mountainsort4.simg'
    # CONTAINER = 'sha1://8743ff094a26bdedd16f36209a05333f1f82fbd8/2019-06-26/mountainsort4.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        from .bandpass_filter import bandpass_filter
        from .whiten import whiten
        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Bandpass filter
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(
                recording=recording, freq_min=self.freq_min, freq_max=self.freq_max)

        # Whiten
        if self.whiten:
            recording = whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers
        )

        # Curate
        # if self.noise_overlap_threshold is not None:
        #     sorting = ml_ms4alg.mountainsort4_curation(
        #         recording=recording,
        #         sorting=sorting,
        #         noise_overlap_threshold=self.noise_overlap_threshold
        #     )

        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class KiloSort2(mlpr.Processor):
    """
    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0
    """

    NAME = 'KiloSort2'
    VERSION = '0.4.4'  # wrapper VERSION
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        default=-1, optional=True,
        description='Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30, optional=True, description='The sigmaMask for kilosort2')
    detect_threshold = mlpr.FloatParameter(optional=True, default=6, description='')
    # prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=150, description='Use 0 for no bandpass filtering')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50, optional=True,
        description='minimum spike rate (Hz), if a cluster falls below this for too long it gets removed')
    car = mlpr.BoolParameter(
        default=True, optional=True,
        description='whether to do common average referencing')

    @staticmethod
    def install():
        print('Auto-installing kilosort2.')
        return install_kilosort2(
            repo='https://github.com/MouseLand/Kilosort2',
            commit='5629125f072795b082245f4265b567d3540cbc2b')

    def run(self):
        import spikesorters as sorters
        print('Kilosort2......')
        try:
            kilosort2_path = KiloSort2.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing kilosort2.')
        sorters.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-' + code

        sorter = sorters.Kilosort2Sorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True)
        sorter.set_params(
            detect_threshold=self.detect_threshold,
            car=self.car,
            minFR=self.minFR,
            electrode_dimensions=None,
            freq_min=self.freq_min,
            sigmaMask=self.adjacency_radius,
            nPCs=self.pc_per_chan)
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class SpykingCircus(mlpr.Processor):
    NAME = 'SpykingCircus'
    VERSION = '0.2.2'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.params']
    CONTAINER = 'sha1://914becce45aec56a84dd1dd4bca4037b09c50373/02-12-2019/spyking_circus.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True, default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True, default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
        num_workers = int(os.environ.get('NUM_WORKERS', 1))
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if not getattr(self, '_keep_temp_files', False):
                if os.path.exists(tmpdir):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class KiloSort2Old(mlpr.Processor):
    """
    KiloSort2 wrapper for SpikeForest framework
    written by J. James Jun, May 7, 2019
    modified from `spiketoolkit/sorters/Kilosort` to be made compatible with SpikeForest

    [Prerequisite]
    1. MATLAB (Tested on R2018b)
    2. CUDA Toolkit v9.1

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/alexmorley/Kilosort2.git`
       Kilosort2 currently doesn't work on tetrodes and low-channel count probes
       (as of May 7, 2019). Clone from Alex Morley's repository that fixed these
       issues. Original Kilosort2 code can be obtained from
       `https://github.com/MouseLand/Kilosort2.git`
    2. (optional) If Alex Morley's latest version doesn't work with SpikeForest,
       run `git checkout 43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956`
    3. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    4. Add `KILOSORT2_PATH=...` in your .bashrc file.
    """

    NAME = 'KiloSort2'
    VERSION = '0.3.3'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        default=-1, optional=True,
        description='Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30, optional=True,
        description='Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=6, description='')
    # prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=150, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50, optional=True,
        description='minimum spike rate (Hz), if a cluster falls below this for too long it gets removed')

    @staticmethod
    def install():
        print('Auto-installing kilosort.')
        return install_kilosort2(
            repo='https://github.com/alexmorley/Kilosort2',
            commit='43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956')

    def run(self):
        _keep_temp_files = True
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort2_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not _keep_temp_files:
                    print('removing tmpdir1')
                    shutil.rmtree(tmpdir)
            raise
        if not _keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
class MountainSort4(mlpr.Processor):
    NAME = 'MountainSort4'
    VERSION = '4.3.1'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    # CONTAINER = 'sha1://e06fee7f72f6b66d80d899ebc08e7c39e5a2458e/2019-05-06/mountainsort4.simg'
    CONTAINER = 'sha1://8743ff094a26bdedd16f36209a05333f1f82fbd8/2019-06-26/mountainsort4.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        # from spikeinterface/spikesorters
        import spikesorters as sorters

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/mountainsort4-' + code

        sorter = sorters.Mountainsort4Sorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )
        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
            curation=False,
            whiten=True,
            filter=True,
            freq_min=self.freq_min,
            freq_max=self.freq_max
        )
        # TODO: get elapsed time from the return of this run
        sorter.run()
        sorting = sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class CustomSorting(mlpr.Processor):
    NAME = 'CustomSorting'
    VERSION = '0.1.7'  # the version can be incremented when the code inside run() changes

    # input files
    recording_file_in = mlpr.Input('Path to raw.mda')
    geom_in = mlpr.Input('Path to geom.csv', optional=True)

    # output files
    firings_out = mlpr.Output('Output firings.mda file')
    firings_curated_out = mlpr.Output('Output firings.curated.mda file')
    metrics_out = mlpr.Output('Metrics .json output')

    # parameters
    samplerate = mlpr.FloatParameter("Sampling frequency")
    mask_out_artifacts = mlpr.BoolParameter(
        optional=True, default=False, description='Whether to mask out artifacts')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        # This temporary directory will automatically be removed even in the case of a python exception
        with TemporaryDirectory() as tmpdir:
            # names of files for the temporary/intermediate data
            filt = tmpdir + '/filt.mda'
            filt2 = tmpdir + '/filt2.mda'
            pre = tmpdir + '/pre.mda'

            print('Bandpass filtering raw -> filt...')
            _bandpass_filter(self.recording_file_in, filt)

            if self.mask_out_artifacts:
                print('Masking out artifacts filt -> filt2...')
                _mask_out_artifacts(filt, filt2)
            else:
                print('Copying filt -> filt2...')
                filt2 = filt

            if self.whiten:
                print('Whitening filt2 -> pre...')
                _whiten(filt2, pre)
            else:
                pre = filt2

            # read the preprocessed timeseries into RAM (maybe we'll do it differently later)
            X = sf.mdaio.readmda(pre)

            # handle the geom
            if type(self.geom_in) == str:
                print('Using geom.csv from a file', self.geom_in)
                geom = read_geom_csv(self.geom_in)
            else:
                # no geom file was provided as input
                num_channels = X.shape[0]
                if num_channels > 6:
                    raise Exception(
                        'For more than six channels, we require that a geom.csv be provided')
                # otherwise make a trivial geometry file
                print('Making a trivial geom file.')
                geom = np.zeros((X.shape[0], 2))

            # Now represent the preprocessed recording using a RecordingExtractor
            recording = se.NumpyRecordingExtractor(
                X, samplerate=self.samplerate, geom=geom)

            # hard-code this for now -- idea: run many simultaneous jobs, each using only 2 cores
            # important to set certain environment variables in the .sh script that calls this .py script
            num_workers = 2

            # Call MountainSort4
            sorting = ml_ms4alg.mountainsort4(
                recording=recording,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                clip_size=self.clip_size,
                detect_threshold=self.detect_threshold,
                detect_interval=self.detect_interval,
                num_workers=num_workers,
            )

            # Write the firings.mda
            print('Writing firings.mda...')
            sf.SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)

            print('Computing cluster metrics...')
            cluster_metrics_path = tmpdir + '/cluster_metrics.json'
            _cluster_metrics(pre, self.firings_out, cluster_metrics_path)

            print('Computing isolation metrics...')
            isolation_metrics_path = tmpdir + '/isolation_metrics.json'
            pair_metrics_path = tmpdir + '/pair_metrics.json'
            _isolation_metrics(pre, self.firings_out, isolation_metrics_path, pair_metrics_path)

            print('Combining metrics...')
            metrics_path = tmpdir + '/metrics.json'
            _combine_metrics(cluster_metrics_path, isolation_metrics_path, metrics_path)

            # copy metrics.json to the output location
            shutil.copy(metrics_path, self.metrics_out)

            print('Creating label map...')
            label_map_path = tmpdir + '/label_map.mda'
            create_label_map(metrics=metrics_path, label_map_out=label_map_path)

            print('Applying label map...')
            apply_label_map(
                firings=self.firings_out,
                label_map=label_map_path,
                firings_out=self.firings_curated_out)
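# Hypothetical illustration of the label-map step above. Assumption: the
# standard MountainSort firings layout, where row 1 holds primary channels,
# row 2 event times, and row 3 unit labels; the real create_label_map /
# apply_label_map are defined elsewhere and may differ in detail.
def _apply_label_map_sketch(firings, label_map):
    curated = firings.copy()
    labels = curated[2, :].astype(int)
    curated[2, :] = label_map[labels]     # remap each event's unit label
    return curated[:, curated[2, :] > 0]  # drop events mapped to 0 (rejected)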
class IronClust(mlpr.Processor):
    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[], description='List of channels to use.')
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood')
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4, description='detection threshold')
    prm_template_name = mlpr.StringParameter(
        optional=True, default='', description='.prm template file name')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0, description='Number of principal components per channel')

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False,
        description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean', description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300,
        description='Batch duration in seconds (clustering time duration)')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20, description='Compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30, description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True, description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8, description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0,
        description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32, description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1, description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='sort_mode')

    @staticmethod
    def install():
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        import spikesorters as sorters
        print('IronClust......')
        try:
            ironclust_path = IronClust.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing ironclust.')
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=False  # will be taken care of by _keep_temp_files one step above
        )
        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class KiloSortOld(mlpr.Processor):
    """
    Kilosort wrapper for SpikeForest framework
    written by J. James Jun, May 21, 2019

    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0

    [Optional: Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/cortex-lab/KiloSort.git`
    2. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    3. Add `KILOSORT_PATH_DEV=...` in your .bashrc file.
    """

    NAME = 'KiloSort'
    VERSION = '0.2.4'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=-1, description='Currently unused')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    # prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    @staticmethod
    def install():
        print('Auto-installing kilosort.')
        return install_kilosort(
            repo='https://github.com/cortex-lab/KiloSort.git',
            commit='3f33771f8fdf8c3846a7f8a75cc8c318b44ed48c')

    def run(self):
        keep_temp_files = False
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan)
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not keep_temp_files:
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class Klusta(mlpr.Processor):
    """
    Installation instruction
        >>> pip install Cython h5py tqdm
        >>> pip install click klusta klustakwik2

    More information on klusta at:
      * https://github.com/kwikteam/phy
      * https://github.com/kwikteam/klusta
    """

    NAME = 'Klusta'
    VERSION = '0.2.2'  # wrapper VERSION
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    # CONTAINER = 'sha1://6d76f22e3b4eff52b430ef4649a8802f7da9e0ec/2019-05-13/klusta.simg'
    CONTAINER = 'sha1://182ff734d38e2ece30ed751de55807b0a8359959/2019-06-28/klusta.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=None, description='')
    detect_sign = mlpr.FloatParameter(optional=True, default=-1, description='')
    threshold_strong_std_factor = mlpr.FloatParameter(optional=True, default=5, description='')
    threshold_weak_std_factor = mlpr.FloatParameter(optional=True, default=2, description='')
    n_features_per_channel = mlpr.IntegerParameter(optional=True, default=3, description='')
    num_starting_clusters = mlpr.IntegerParameter(optional=True, default=3, description='')
    extract_s_before = mlpr.IntegerParameter(optional=True, default=16, description='')
    extract_s_after = mlpr.IntegerParameter(optional=True, default=32, description='')

    def run(self):
        import spikesorters as sorters
        print('Klusta......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/klusta-' + code

        sorter = sorters.KlustaSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )
        sorter.set_params(
            adjacency_radius=self.adjacency_radius,
            detect_sign=self.detect_sign,
            threshold_strong_std_factor=self.threshold_strong_std_factor,
            threshold_weak_std_factor=self.threshold_weak_std_factor,
            n_features_per_channel=self.n_features_per_channel,
            num_starting_clusters=self.num_starting_clusters,
            extract_s_before=self.extract_s_before,
            extract_s_after=self.extract_s_after
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()
        SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
class MountainSort4(mlpr.Processor):
    NAME = 'MountainSort4'
    VERSION = '4.2.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = 'sha1://009406add7a55687cec176be912bc7685c2a4b1d/02-12-2019/mountainsort4.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        import spikeextractors as se
        import spiketoolkit as st
        import ml_ms4alg

        print('MountainSort4......')
        recording = se.MdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Bandpass filter
        if self.freq_min or self.freq_max:
            recording = st.preprocessing.bandpass_filter(
                recording=recording, freq_min=self.freq_min, freq_max=self.freq_max)

        # Whiten
        if self.whiten:
            recording = st.preprocessing.whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers)

        # Curate
        # if self.noise_overlap_threshold is not None:
        #     sorting = ml_ms4alg.mountainsort4_curation(
        #         recording=recording,
        #         sorting=sorting,
        #         noise_overlap_threshold=self.noise_overlap_threshold
        #     )

        se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)