class CreateWaveformsPlot(mlpr.Processor):
    NAME = 'CreateWaveformsPlot'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    units = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        S = si.MdaSortingExtractor(firings_file=self.firings)
        channels = R.getChannelIds()
        if len(channels) > 20:
            channels = channels[0:20]
        if len(self.units) > 0:
            units = self.units
        else:
            units = S.getUnitIds()
        if len(units) > 20:
            units = units[::int(len(units) / 20)]
        sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
        save_plot(self.jpg_out)
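# Usage sketch (not part of the original module): mlpr processors such as the
# ones in this file are typically invoked through Processor.execute(**kwargs),
# with each output given as a {'ext': ...} placeholder. This assumes mlpr's
# execute API as used elsewhere in SpikeForest; the paths below are
# hypothetical placeholders.
#
#   result = CreateWaveformsPlot.execute(
#       recording_dir='/path/to/recording',
#       firings='/path/to/recording/firings.mda',
#       channels=[], units=[],
#       jpg_out={'ext': '.jpg'})
#   print(result.outputs['jpg_out'])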
class ComputeUnitDetail(mlpr.Processor):
    NAME = 'ComputeUnitDetail'
    VERSION = '0.1.0'
    CONTAINER = None
    recording_dir = mlpr.Input(description='Recording directory', optional=False, directory=True)
    firings = mlpr.Input(description='Input firings.mda file')
    unit_id = mlpr.IntegerParameter(description='Unit ID')
    json_out = mlpr.Output(description='Output .json file')

    def run(self):
        # Use the declared input name (recording_dir), not recording_directory
        recording = SFMdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        waveforms0 = _get_random_spike_waveforms(recording=recording, sorting=sorting, unit=self.unit_id)
        channel_ids = recording.get_channel_ids()
        avg_waveform = np.median(waveforms0, axis=2)
        ret = dict(channel_ids=channel_ids, average_waveform=avg_waveform.tolist())
        with open(self.json_out, 'w') as f:
            json.dump(ret, f)
class GenSortingComparisonTable(mlpr.Processor):
    VERSION = '0.2.6'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        print('GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'.format(
            self.firings, self.firings_true, self.units_true))
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        sorting_true = SFMdaSortingExtractor(firings_file=self.firings_true)
        if (self.units_true is not None) and (len(self.units_true) > 0):
            sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
        SC = SortingComparison(sorting_true, sorting, delta_tp=30)
        df = get_comparison_data_frame(comparison=SC)
        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        json = df.transpose().to_dict()
        html = df.to_html(index=False)
        _write_json_file(json, self.json_out)
        _write_json_file(html, self.html_out)
class ComputeUnitsInfo(mlpr.Processor):
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.1'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if (self.channel_ids) and (len(self.channel_ids) > 0):
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
        recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        ef = int(1e6)
        recording_sub = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=ef)
        recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
        sorting_sub = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=ef)
        unit_ids = self.unit_ids
        if (not unit_ids) or (len(unit_ids) == 0):
            unit_ids = sorting.getUnitIds()
        channel_noise_levels = compute_channel_noise_levels(recording=recording)
        print('computing templates...')
        templates = compute_unit_templates(recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
        print('.')
        ret = []
        for i, unit_id in enumerate(unit_ids):
            template = templates[i]
            info0 = dict()
            info0['unit_id'] = int(unit_id)
            info0['snr'] = compute_template_snr(template, channel_noise_levels)
            peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
            info0['peak_channel'] = int(recording.getChannelIds()[peak_channel_index])
            train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            info0['num_events'] = int(len(train))
            info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
            ret.append(info0)
        write_json_file(self.json_out, ret)
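# The compute_template_snr helper used above is not defined in this file. A
# minimal sketch of a plausible implementation, inferred from the inline SNR
# computation in ComputeUnitsInfo 0.1.5k below (peak absolute template
# amplitude divided by the noise level on that channel); the real helper may
# differ, so this stand-in is named as a sketch.
def compute_template_snr_sketch(template, channel_noise_levels):
    # template: channels x timepoints array; channel_noise_levels is indexed
    # by channel position (not channel ID)
    channel_snrs = [
        np.max(np.abs(template[m, :])) / channel_noise_levels[m]
        for m in range(template.shape[0])
    ]
    return float(np.max(channel_snrs))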
class ComputeUnitsInfo(mlpr.Processor):
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.3'
    recording = mlpr.Input()
    sorting = mlpr.Input()
    json_out = mlpr.Output()

    def run(self):
        info0 = sa.compute_units_info(recording=self.recording, sorting=self.sorting)
        with open(self.json_out, 'w') as f:
            json.dump(info0, f)
class ComputeUnitsInfo(mlpr.Processor):
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.5k'
    CONTAINER = _CONTAINER
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        import spikewidgets as sw
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        if (self.channel_ids) and (len(self.channel_ids) > 0):
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
        recording = R0
        # recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        unit_ids = self.unit_ids
        if (not unit_ids) or (len(unit_ids) == 0):
            unit_ids = sorting.getUnitIds()
        channel_noise_levels = compute_channel_noise_levels(recording=recording)
        # No longer use a subset to compute the templates
        templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)
        ret = []
        for i, unit_id in enumerate(unit_ids):
            template = templates[i]
            max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
            peak_channel_index = np.argmax(max_p2p_amps_on_channels)
            peak_channel = recording.getChannelIds()[peak_channel_index]
            # Select the peak channel by its channel ID (not its index)
            R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
            R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
            templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
            template2 = templates2[0]
            info0 = dict()
            info0['unit_id'] = int(unit_id)
            info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
            # info0['snr'] = compute_template_snr(template, channel_noise_levels)
            # peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
            info0['peak_channel'] = int(peak_channel)
            train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            info0['num_events'] = int(len(train))
            info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
            ret.append(info0)
        write_json_file(self.json_out, ret)
class PlotUnitWaveforms(mlpr.Processor):
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        sw.UnitWaveformsWidget(recording=recording, sorting=sorting).plot()
        save_plot(self.plot_out)
class ComputeUnitTemplates(mlpr.Processor):
    NAME = 'ComputeUnitTemplates'
    VERSION = '0.1.4'
    recording = mlpr.Input()
    sorting = mlpr.Input()
    templates_out = mlpr.OutputArray()

    def run(self):
        templates = compute_unit_templates(
            recording=self.recording,
            sorting=self.sorting,
            unit_ids=self.sorting.get_unit_ids())  # pylint: disable=no-member
        print('Saving templates...', self.templates_out)
        np.save(self.templates_out, templates)
class PlotAutoCorrelograms(mlpr.Processor):
    NAME = 'spikeforest.PlotAutoCorrelograms'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        sw.CrossCorrelogramsWidget(samplerate=recording.getSamplingFrequency(), sorting=sorting).plot()
        save_plot(self.plot_out)
class CreateTimeseriesPlot(mlpr.Processor):
    NAME = 'CreateTimeseriesPlot'
    VERSION = '0.1.7'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
        R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
        N = R.getNumFrames()
        N2 = int(N / 2)
        channels = R.getChannelIds()
        if len(channels) > 20:
            channels = channels[0:20]
        sw.TimeseriesWidget(recording=R, trange=[N2 - 4000, N2], channels=channels, width=12, height=5).plot()
        save_plot(self.jpg_out)
class ExtractTwoPhotonSeriesMp4(mlpr.Processor):
    NAME = 'ExtractTwoPhotonSeriesMp4'  # was 'H5ToDict', an apparent copy-paste slip
    VERSION = '0.1.1'

    # Inputs
    nwb_in = mlpr.Input()

    # Outputs
    mp4_out = mlpr.Output()

    def run(self):
        nwb_obj = nwb_to_dict(self.nwb_in, use_cache=True)
        npy_path = nwb_obj['acquisition']['TwoPhotonSeries']['_datasets']['data']['_data']
        npy_path2 = mt.realizeFile(npy_path)
        if not npy_path2:
            nwb_obj = nwb_to_dict(self.nwb_in, use_cache=False)
            npy_path = nwb_obj['acquisition']['TwoPhotonSeries']['_datasets']['data']['_data']
            npy_path2 = mt.realizeFile(npy_path)
            if not npy_path2:
                self._set_error('Unable to realize npy file: {}'.format(npy_path))
                return
        X = np.load(npy_path2)
        # Note that there is a bug in imageio.mimwrite that prevents us from
        # writing to a memory buffer.
        # See: https://github.com/imageio/imageio/issues/157
        imageio.mimwrite(self.mp4_out, X, format='mp4', fps=10)
class KiloSort(mlpr.Processor):
    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    # prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class GenSortingComparisonTableNew(mlpr.Processor):
    VERSION = '0.3.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        print('GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'.format(
            self.firings, self.firings_true, self.units_true))
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        sorting_true = SFMdaSortingExtractor(firings_file=self.firings_true)
        if (self.units_true is not None) and (len(self.units_true) > 0):
            sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
        SC = st.comparison.compare_sorter_to_ground_truth(
            gt_sorting=sorting_true,
            tested_sorting=sorting,
            delta_time=0.3,
            min_accuracy=0,
            compute_misclassification=False,
            exhaustive_gt=False  # Fix this in future
        )
        df = pd.concat([SC.count, SC.get_performance()], axis=1).reset_index()
        df = df.rename(columns=dict(
            gt_unit_id='unit_id',
            fp='num_false_positives',
            fn='num_false_negatives',
            tested_id='best_unit',
            tp='num_matches'))
        df['matched_unit'] = df['best_unit']
        df['f_p'] = 1 - df['precision']
        df['f_n'] = 1 - df['recall']
        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        json = df.transpose().to_dict()
        html = df.to_html(index=False)
        _write_json_file(json, self.json_out)
        _write_json_file(html, self.html_out)
class MountainSort4(mlpr.Processor):
    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    dataset_dir = mlpr.Input('Directory of dataset', directory=True)
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(optional=True, default=True, description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(optional=True, default=10, description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(optional=True, default=0.15, description='Use None for no automated curation')

    def run(self):
        recording = si.MdaRecordingExtractor(self.dataset_dir)
        num_workers = int(os.environ.get('NUM_WORKERS', -1))
        if num_workers <= 0:
            num_workers = None
        sorting = sf.sorters.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            whiten=self.whiten,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            noise_overlap_threshold=self.noise_overlap_threshold,
            num_workers=num_workers)
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
class CreateEfficientAccessRecordingFile(mlpr.Processor):
    NAME = 'CreateEfficientAccessRecordingFile'
    VERSION = '0.1.2'
    recording = mlpr.Input()
    segment_size = mlpr.IntegerParameter(optional=True, default=1000000)
    hdf5_out = mlpr.Output()

    def run(self):
        import h5py
        recording = self.recording
        segment_size = self.segment_size
        channel_ids = recording.get_channel_ids()
        samplerate = recording.get_sampling_frequency()
        M = len(channel_ids)  # number of channels
        N = recording.get_num_frames()  # number of timepoints
        num_segments = int(np.ceil(N / segment_size))
        try:
            channel_locations = recording.get_channel_locations(channel_ids=channel_ids)
            nd = len(channel_locations[0])
            geom = np.zeros((M, nd))
            for m in range(M):
                geom[m, :] = channel_locations[m]
        except:
            nd = 2
            geom = np.zeros((M, nd))
        with h5py.File(self.hdf5_out, "w") as f:
            f.create_dataset('segment_size', data=[segment_size])
            f.create_dataset('num_segments', data=[num_segments])
            f.create_dataset('num_channels', data=[M])
            f.create_dataset('channel_ids', data=np.array(channel_ids))
            f.create_dataset('num_timepoints', data=[N])
            f.create_dataset('samplerate', data=[samplerate])
            f.create_dataset('geom', data=geom)
            if callable(recording.hash):
                hash0 = recording.hash()
            else:
                hash0 = recording.hash
            f.create_dataset('recording_hash', data=np.array([hash0.encode()]))
            for j in range(num_segments):
                segment = np.zeros((M, segment_size), dtype=float)  # fix dtype here
                t1 = int(j * segment_size)  # first timepoint of the segment
                t2 = int(np.minimum(N, (t1 + segment_size)))  # last timepoint of segment (+1)
                s1 = int(np.maximum(0, t1))  # first timepoint
                s2 = int(np.minimum(N, t2))  # last timepoint (+1)
                # determine aa so that t1-s1+aa = 0
                # so, aa = -(t1-s1)
                aa = -(t1 - s1)
                # Read the segment
                segment[:, aa:aa + s2 - s1] = recording.get_traces(start_frame=s1, end_frame=s2)
                for ii, ch in enumerate(channel_ids):
                    f.create_dataset('part-{}-{}'.format(ch, j), data=segment[ii, :].ravel())
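# Hypothetical reader for the HDF5 layout written by
# CreateEfficientAccessRecordingFile above (not part of the original module);
# the dataset names ('segment_size', 'channel_ids', 'part-<channel>-<segment>')
# mirror those used by the writer.
def read_segment_traces(hdf5_path, segment_index):
    import h5py
    with h5py.File(hdf5_path, 'r') as f:
        channel_ids = f['channel_ids'][...]
        segment_size = int(f['segment_size'][0])
        traces = np.zeros((len(channel_ids), segment_size))
        for ii, ch in enumerate(channel_ids):
            traces[ii, :] = f['part-{}-{}'.format(ch, segment_index)][...]
        return traces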
class Waveclus(mlpr.Processor):
    """
    Wave_clus wrapper written by J. James Jun, May 21, 2019

    [Optional: Installation instructions in the SpikeForest environment]
    1. Run `git clone https://github.com/csn-le/wave_clus.git`
    2. Activate the conda environment for SpikeForest
    3. Create `WAVECLUS_PATH_DEV`

    Algorithm website:
    https://github.com/csn-le/wave_clus/wiki
    """
    NAME = 'waveclus'
    VERSION = '0.0.5'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')

    def run(self):
        tmpdir = _get_tmpdir('waveclus')
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            # if len(self.channels) > 0:
            #     recording = se.SubRecordingExtractor(
            #         parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            all_params = dict()
            for param0 in self.PARAMETERS:
                all_params[param0.name] = getattr(self, param0.name)
            sorting = waveclus_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **all_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    print('erasing temporary directory (on error)')
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            print('erasing temporary directory')
            shutil.rmtree(tmpdir)
class GenSortingComparisonTable(mlpr.Processor):
    VERSION = '0.1.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output('Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output('Table as .html file produced from pandas dataframe')

    def run(self):
        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        sorting_true = si.MdaSortingExtractor(firings_file=self.firings_true)
        if len(self.units_true) > 0:
            sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=self.units_true)
        SC = st.comparison.SortingComparison(sorting_true, sorting)
        df = sw.SortingComparisonTable(comparison=SC).getDataframe()
        json = df.transpose().to_dict()
        html = df.to_html(index=False)
        _write_json_file(json, self.json_out)
        _write_json_file(html, self.html_out)
class YASS(mlpr.Processor):
    NAME = 'YASS'
    VERSION = '0.1.0'
    # used by the container to pass the env variables
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS']
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for the container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    # paramfile_out = mlpr.Output('YASS yaml config file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code
        # num_workers = os.environ.get('NUM_WORKERS', 1)
        # print('num_workers: {}'.format(num_workers))
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting, yaml_file = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            se.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
            # shutil.copyfile(yaml_file, self.paramfile_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
class GenerateMearecRecording(mlpr.Processor):
    NAME = "GenerateMearecRecording"
    VERSION = "0.1.0"

    # input file
    templates_in = mlpr.Input(description='.h5 file containing templates')

    # output file
    recording_out = mlpr.Output()

    # recordings params
    drifting = mlpr.BoolParameter()
    noise_level = mlpr.FloatParameter()
    bursting = mlpr.BoolParameter()
    shape_mod = mlpr.BoolParameter()

    # spiketrains params
    duration = mlpr.FloatParameter()
    n_exc = mlpr.IntegerParameter()
    n_inh = mlpr.IntegerParameter()

    # templates params
    min_dist = mlpr.FloatParameter()

    # seed
    seed = mlpr.IntegerParameter()

    def run(self):
        recordings_params = deepcopy(mr.get_default_recordings_params())
        recordings_params['recordings']['drifting'] = self.drifting
        recordings_params['recordings']['noise_level'] = self.noise_level
        recordings_params['recordings']['bursting'] = self.bursting
        recordings_params['recordings']['shape_mod'] = self.shape_mod
        recordings_params['recordings']['seed'] = self.seed
        # recordings_params['recordings']['chunk_conv_duration'] = 0  # turn off parallel execution
        recordings_params['spiketrains']['duration'] = self.duration
        recordings_params['spiketrains']['n_exc'] = self.n_exc
        recordings_params['spiketrains']['n_inh'] = self.n_inh
        recordings_params['spiketrains']['seed'] = self.seed
        recordings_params['templates']['min_dist'] = self.min_dist
        recordings_params['templates']['seed'] = self.seed

        # this is needed because mr.load_templates requires the file extension
        templates_fname = self.templates_in + '.h5'
        shutil.copyfile(self.templates_in, templates_fname)
        tempgen = mr.load_templates(Path(templates_fname))

        recgen = mr.gen_recordings(params=recordings_params, tempgen=tempgen, verbose=False)
        mr.save_recording_generator(recgen, self.recording_out)
        del recgen
class CreateSpikeSprays(mlpr.Processor):
    NAME = 'CreateSpikeSprays'
    VERSION = '0.1.0'
    CONTAINER = _CONTAINER
    recording_directory = mlpr.Input(description='Recording directory', optional=False, directory=True)
    filtered_timeseries = mlpr.Input(description='Filtered timeseries file (.mda)', optional=False)
    firings_true = mlpr.Input(description='True firings -- firings_true.mda', optional=False)
    firings_sorted = mlpr.Input(description='Sorted firings -- firings.mda', optional=False)
    unit_id_true = mlpr.IntegerParameter(description='ID of the true unit')
    unit_id_sorted = mlpr.IntegerParameter(description='ID of the sorted unit')
    neighborhood_size = mlpr.IntegerParameter(description='Max size of the electrode neighborhood', optional=True, default=7)
    num_spikes = mlpr.IntegerParameter(description='Max number of spikes in the spike spray', optional=True, default=20)
    json_out = mlpr.Output(description='Output json object')

    def run(self):
        rx = SFMdaRecordingExtractor(
            dataset_directory=self.recording_directory,
            download=True,
            raw_fname=self.filtered_timeseries)
        sx_true = SFMdaSortingExtractor(firings_file=self.firings_true)
        sx = SFMdaSortingExtractor(firings_file=self.firings_sorted)
        ssobj = create_spikesprays(
            rx=rx,
            sx_true=sx_true,
            sx_sorted=sx,
            neighborhood_size=self.neighborhood_size,
            num_spikes=self.num_spikes,
            unit_id_true=self.unit_id_true,
            unit_id_sorted=self.unit_id_sorted)
        with open(self.json_out, 'w') as f:
            json.dump(ssobj, f)
class ComputeUnitDetail(mlpr.Processor):
    NAME = 'ComputeUnitDetail'  # was 'ComputeUnitsDetail', an apparent typo
    VERSION = '0.1.5'
    recording = mlpr.Input()
    sorting = mlpr.Input()
    unit_id = mlpr.IntegerParameter()
    output = mlpr.Output()

    def run(self):
        event_times = self.sorting.get_unit_spike_train(unit_id=self.unit_id)  # pylint: disable=no-member
        snippets = self.recording.get_snippets(reference_frames=event_times, snippet_len=100)  # pylint: disable=no-member
        template = np.median(np.stack(snippets), axis=0)
        result0 = dict(
            unit_id=self.unit_id,
            num_events=len(event_times),
            event_times=event_times,
            snippets=snippets,
            template=template)
        with open(self.output, 'wb') as f:
            pickle.dump(result0, f)
class AffinityPropagation(mlpr.Processor):
    VERSION = '0.1.0'
    data = mlpr.Input()
    labels_out = mlpr.Output(is_array=True)
    damping = mlpr.FloatParameter()

    def run(self):
        from sklearn.cluster import AffinityPropagation
        import numpy as np
        A = AffinityPropagation(damping=self.damping).fit(np.load(self.data))
        # np.save appends '.npy' to the file name, so save under the suffixed
        # name and then rename to the expected output path
        np.save(self.labels_out + '.npy', A.labels_)
        os.rename(self.labels_out + '.npy', self.labels_out)
class ComputeDatasetInfo(mlpr.Processor):
    NAME = 'ComputeDatasetInfo'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        ret = {}
        recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        ret['samplerate'] = recording.getSamplingFrequency()
        ret['num_channels'] = len(recording.getChannelIds())
        ret['duration_sec'] = recording.getNumFrames() / ret['samplerate']
        write_json_file(self.json_out, ret)
class ComputeUnitsInfo(mlpr.Processor):
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.8'
    CONTAINER = _CONTAINER
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        R0 = SFMdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        ret = compute_units_info(
            recording=R0,
            sorting=sorting,
            channel_ids=self.channel_ids,
            unit_ids=self.unit_ids)
        write_json_file(self.json_out, ret)
class SpykingCircus(mlpr.Processor):
    NAME = 'SpykingCircus'
    VERSION = '0.1.2'
    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(optional=True, default=100, description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(optional=True, default=1000, description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(optional=True, default=10000, description='I believe it relates to subsampling and affects compute time')

    def run(self):
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        # temp dir renamed from 'ironclust-tmp-' (apparent copy-paste slip)
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code
        # n_cores expects an integer; environment variables are strings
        num_workers = int(os.environ.get('NUM_WORKERS', 2))
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    NAME = 'IronClust'
    VERSION = '4.2.6'
    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    detect_sign = mlpr.IntegerParameter('Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter('Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(optional=True, default=300, description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(optional=True, default=6000, description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True, default=3, description='TODO')

    def run(self):
        ironclust_src = os.environ.get('IRONCLUST_SRC', None)
        if not ironclust_src:
            raise Exception('Environment variable not set: IRONCLUST_SRC')
        code = ''.join(random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code
        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  # TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=ironclust_src
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
        except:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
class RepeatText(mlpr.Processor):
    textfile = mlpr.Input(description='input text file')
    textfile_out = mlpr.Output(description='output text file')
    # A scalar parameter: num_repeats is used as a single integer below,
    # so IntegerListParameter was a bug
    num_repeats = mlpr.IntegerParameter(description='Number of times to repeat the text')

    def run(self):
        assert self.num_repeats >= 0
        with open(self.textfile, 'r') as f:
            txt = f.read()
        txt2 = ''
        for _ in range(self.num_repeats):
            txt2 = txt2 + txt
        with open(self.textfile_out, 'w') as f:
            f.write(txt2)
class ComputeRecordingInfo(mlpr.Processor):
    NAME = 'ComputeRecordingInfo'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True, description='Recording directory')
    channels = mlpr.IntegerListParameter(description='List of channels to use.', optional=True, default=[])
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        ret = {}
        recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
        if len(self.channels) > 0:
            recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
        ret['samplerate'] = recording.getSamplingFrequency()
        ret['num_channels'] = len(recording.getChannelIds())
        ret['duration_sec'] = recording.getNumFrames() / ret['samplerate']
        write_json_file(self.json_out, ret)
class MeanShift(mlpr.Processor):
    VERSION = '0.1.0'
    data = mlpr.Input()
    labels_out = mlpr.Output(is_array=True)
    bandwidth = mlpr.StringParameter()

    def run(self):
        from sklearn.cluster import MeanShift
        import numpy as np
        if self.bandwidth == 'auto':
            bandwidth = None  # let sklearn estimate the bandwidth
        else:
            bandwidth = float(self.bandwidth)
        A = MeanShift(bandwidth=bandwidth).fit(np.load(self.data))
        # np.save appends '.npy' to the file name, so save under the suffixed
        # name and then rename to the expected output path
        np.save(self.labels_out + '.npy', A.labels_)
        os.rename(self.labels_out + '.npy', self.labels_out)
class CombineSubsampledMandelbrot(mlpr.Processor):
    NAME = 'CombineSubsampledMandelbrot'
    VERSION = '0.1.1'
    X_list = mlpr.Input(multi=True)
    X_out = mlpr.Output()
    num_x = mlpr.IntegerParameter()

    def run(self):
        print('=== CombineSubsampledMandelbrot ===', self.num_x)
        arrays = []
        for X0 in self.X_list:
            arrays.append(np.load(X0))
        X = combine_subsampled_mandelbrot(arrays)
        X = X[:self.num_x, :]
        np.save(self.X_out, X)