class IronClust(mlpr.Processor):
    """Run IronClust (via ``sf.sorters.ironclust``) on an MDA recording directory.

    The sorter runs inside a randomly named scratch directory under
    ``$TEMPDIR`` (default ``/tmp``); the resulting sorting is written to
    ``firings_out`` with ``si.MdaSortingExtractor.writeSorting``.
    """

    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording and write the firings file.

        The scratch directory is removed afterwards, on success or failure,
        unless ``self._keep_temp_files`` is truthy (presumably set by the
        mlprocessors framework -- treated as False when absent).

        Raises:
            Exception: if the ``IRONCLUST_SRC`` environment variable is unset
                or empty.
        """
        # Forwarded to sf.sorters.ironclust as ``ironclust_src``.
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')

        # Random suffix so that concurrent runs get distinct scratch directories.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if self.channels:
                # Restrict sorting to the requested channel subset.
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO (carried over from the original source)
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup point for both the success and the failure path;
            # any exception raised above still propagates after this runs.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
class MeanShift(mlpr.Processor):
    """Cluster the samples stored in a ``.npy`` file with scikit-learn's MeanShift.

    The per-sample cluster labels (``labels_`` of the fitted model) are
    written in ``.npy`` format to ``labels_out``.
    """

    VERSION = '0.1.0'

    # Path of a .npy file whose array is passed to MeanShift.fit (loaded with np.load).
    data = mlpr.Input()
    # Destination path for the cluster labels, saved in .npy format.
    labels_out = mlpr.Output(is_array=True)
    # 'auto' -> bandwidth=None (scikit-learn estimates it); otherwise parsed with float().
    bandwidth = mlpr.StringParameter()

    def run(self):
        """Fit MeanShift to ``data`` and save the resulting labels to ``labels_out``.

        Raises:
            ValueError: if ``bandwidth`` is neither ``'auto'`` nor a value
                accepted by ``float()``.
        """
        # Aliased so the import does not shadow this processor class's own name.
        from sklearn.cluster import MeanShift as SkMeanShift
        import numpy as np

        bandwidth = None if self.bandwidth == 'auto' else float(self.bandwidth)
        model = SkMeanShift(bandwidth=bandwidth).fit(np.load(self.data))

        # np.save appends '.npy' to file names that lack it, so save under an
        # explicit '.npy' name and then move the file onto the requested path.
        tmp_path = self.labels_out + '.npy'
        np.save(tmp_path, model.labels_)
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises on Windows when the target already exists.
        os.replace(tmp_path, self.labels_out)
class H5ToDict(mlpr.Processor):
    """Convert an HDF5 file to a dict with ``h5_to_dict`` and dump it as JSON."""

    NAME = 'H5ToDict'
    VERSION = '0.1.1'

    # Input: path of the HDF5 file to convert.
    h5_in = mlpr.Input()

    # Parameter: forwarded to h5_to_dict as ``upload_to``; an empty string is sent as None.
    upload_to = mlpr.StringParameter(optional=True, default='')

    # Output: path of the JSON file to write.
    json_out = mlpr.Output()

    def run(self):
        """Write the dict produced by ``h5_to_dict`` to ``json_out``."""
        # Any falsy value (notably the default '') collapses to None.
        target = self.upload_to or None
        converted = h5_to_dict(self.h5_in, upload_to=target)
        # ignore_nan=True makes simplejson emit NaN/Infinity as JSON null.
        with open(self.json_out, 'w') as fp:
            simplejson.dump(converted, fp, ignore_nan=True)
class JRClust(mlpr.Processor):
    """
    JRClust v4 wrapper written by J. James Jun, May 8, 2019
    modified from `spikeforest/spikesorters/ironclust/ironclust.py`

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/JaneliaSciComp/JRCLUST`
    2. Activate conda environment for SpikeForest
    3. Create `JRCLUST_PATH` and `MDAIO_PATH`
    4. Flatiron execution: `module load matlab/R2019a cuda/10.0.130_410.48`
        DO NOT LOAD `module load gcc`.
        DO NOT USE `matlab/R2018b` (it will crash while creating a parpool).

    Every declared processor parameter is forwarded by name to
    ``jrclust_helper`` together with the dataset params read from
    ``recording_dir``; the sorting is written to ``firings_out``.

    See: James Jun, et al. Real-time spike sorting platform for
    high-density extracellular probes with ground-truth validation and drift correction
    https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=4.5, description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=3000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=1, description='Number of principal components per channel')

    # Preprocessing / clustering options. Like every parameter above, these
    # are forwarded unchanged to jrclust_helper by run().
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(optional=True, default='none', description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(optional=True, default=1, description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(optional=True, default=0, description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        """Sort the recording with JRCLUST and write the firings file.

        The scratch directory from ``_get_tmpdir('jrclust')`` is removed
        afterwards, on success or failure, unless ``self._keep_temp_files`` is
        truthy (treated as False when absent).
        """
        tmpdir = _get_tmpdir('jrclust')
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if self.channels:
                # Restrict sorting to the requested channel subset.
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Forward every declared processor parameter to the helper by name.
            all_params = {param0.name: getattr(self, param0.name) for param0 in self.PARAMETERS}

            sorting = jrclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **all_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup point for both paths; exceptions still propagate.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    """Run IronClust through the ``spikesorters`` IronClustSorter wrapper.

    IronClust is installed at a pinned commit on every run (see ``install``).
    The value returned by ``sorter.run()`` is printed as
    ``#SF-SORTER-RUNTIME#<value>#`` and the sorting is written to
    ``firings_out``.
    """

    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    channels = mlpr.IntegerListParameter(
        optional=True, default=[],
        description='List of channels to use.'
    )
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood'
    )
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood'
    )
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4,
        description='detection threshold'
    )
    prm_template_name = mlpr.StringParameter(
        optional=True, default='',
        description='.prm template file name'
    )
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering'
    )
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000,
        description='Use 0 for no bandpass filtering'
    )
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging'
    )
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0,
        description='Number of principal components per channel'
    )

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False, description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean', description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300, description='batch duration in seconds. clustering time duration')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20, description='compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30, description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True, description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8, description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0, description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32, description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1, description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='sort_mode')

    @staticmethod
    def install():
        """Install IronClust at a pinned commit.

        Returns:
            Whatever ``install_ironclust`` returns; ``run`` passes it to
            ``IronClustSorter.set_ironclust_path``.
        """
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        """Install IronClust, sort the recording and write the firings file.

        Raises:
            Exception: if installing IronClust fails (the original error is
                printed and chained as ``__cause__``).
        """
        import spikesorters as sorters

        print('IronClust......')
        try:
            ironclust_path = IronClust.install()
        except Exception as err:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are no longer converted; the cause is kept via ``from err``.
            traceback.print_exc()
            raise Exception('Problem installing ironclust.') from err
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # NOTE(review): ``self.channels`` is declared above but never applied to
        # the recording in this method -- confirm whether subsetting is intended.

        # Random suffix so that concurrent runs get distinct output folders.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=False  # will be taken care by _keep_temp_files one step above
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
class IronClust(mlpr.Processor):
    """Run IronClust (via ``ironclust_helper``) on an MDA recording directory.

    Requires the ``IRONCLUST_PATH`` environment variable. The sorter runs in
    a randomly named scratch directory under ``$TEMPDIR`` (default ``/tmp``)
    and the sorting is written to ``firings_out``.
    """

    NAME = 'IronClust'
    VERSION = '0.2.0'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.m']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=5, description='detection threshold')
    prm_template_name = mlpr.StringParameter(
        optional=True, description='.prm template file name')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording with IronClust and write the firings file.

        The scratch directory is removed afterwards, on success or failure,
        unless ``self._keep_temp_files`` is truthy (treated as False when
        absent).

        Raises:
            Exception: if the ``IRONCLUST_PATH`` environment variable is unset
                or empty.
        """
        ironclust_path = os.environ.get('IRONCLUST_PATH', None)
        if not ironclust_path:
            raise Exception('Environment variable not set: IRONCLUST_PATH')

        # Random suffix so that concurrent runs get distinct scratch directories.
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if self.channels:
                # Restrict sorting to the requested channel subset.
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = ironclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_path=ironclust_path,
                params=params,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        finally:
            # Single cleanup point for both paths; exceptions still propagate.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)