class CreateWaveformsPlot(mlpr.Processor):
    """Plot unit waveforms of a sorted recording and save the figure.

    The recording is optionally restricted to ``channels``, wrapped in a
    bandpass filter (freq_min=300, freq_max=6000) and plotted together with
    the sorting loaded from ``firings``.
    """
    NAME = 'CreateWaveformsPlot'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    units = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        """Create the waveform plot and pass ``jpg_out`` to ``save_plot``."""
        raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=True)
        if len(self.channels) > 0:
            raw = si.SubRecordingExtractor(parent_recording=raw,
                                           channel_ids=self.channels)
        filtered = sw.lazyfilters.bandpass_filter(recording=raw,
                                                  freq_min=300,
                                                  freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)

        # Only the first 20 channels are plotted.
        plot_channels = filtered.getChannelIds()
        if len(plot_channels) > 20:
            plot_channels = plot_channels[0:20]

        # Explicitly requested units win; otherwise use every sorted unit.
        plot_units = self.units if len(self.units) > 0 else sorting.getUnitIds()
        if len(plot_units) > 20:
            # Subsample with an integer stride. The stride rounds down, so
            # more than 20 units can remain (e.g. 21-39 units give stride 1).
            plot_units = plot_units[::len(plot_units) // 20]

        widget = sw.UnitWaveformsWidget(recording=filtered,
                                        sorting=sorting,
                                        channels=plot_channels,
                                        unit_ids=plot_units)
        widget.plot()
        save_plot(self.jpg_out)
Beispiel #2
0
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit summary information and write it as JSON.

    For every unit the output holds: unit id, SNR, peak channel, number of
    events and firing rate. Templates are computed from the first 1e6 frames
    of the bandpass-filtered recording.
    """
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.1'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        """Load recording and sorting, compute the unit table, write JSON."""
        raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=True)
        if (self.channel_ids) and (len(self.channel_ids) > 0):
            raw = si.SubRecordingExtractor(parent_recording=raw,
                                           channel_ids=self.channel_ids)
        recording = sw.lazyfilters.bandpass_filter(
            recording=raw, freq_min=300, freq_max=6000)
        sorting = si.MdaSortingExtractor(firings_file=self.firings)

        # Templates are computed on the leading frames only, held in memory.
        last_frame = int(1e6)
        recording_head = MemoryRecordingExtractor(
            parent_recording=si.SubRecordingExtractor(
                parent_recording=recording, start_frame=0,
                end_frame=last_frame))
        sorting_head = si.SubSortingExtractor(
            parent_sorting=sorting, start_frame=0, end_frame=last_frame)

        selected_units = self.unit_ids
        if (not selected_units) or (len(selected_units) == 0):
            selected_units = sorting.getUnitIds()

        # Noise levels come from the full filtered recording, not the head.
        noise_levels = compute_channel_noise_levels(recording=recording)
        print('computing templates...')
        templates = compute_unit_templates(
            recording=recording_head, sorting=sorting_head,
            unit_ids=selected_units)
        print('.')

        rows = []
        for idx, unit_id in enumerate(selected_units):
            template = templates[idx]
            snr = compute_template_snr(template, noise_levels)
            # Peak channel: row of the template with the largest |value|.
            peak_index = np.argmax(np.max(np.abs(template), axis=1))
            peak_channel = int(recording.getChannelIds()[peak_index])
            spike_train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            duration_sec = (recording.getNumFrames() /
                            recording.getSamplingFrequency())
            rows.append({
                'unit_id': int(unit_id),
                'snr': snr,
                'peak_channel': peak_channel,
                'num_events': int(len(spike_train)),
                'firing_rate': float(len(spike_train) / duration_sec),
            })
        write_json_file(self.json_out, rows)
Beispiel #3
0
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit summary information and write it as JSON.

    For every unit the output holds: unit id, SNR, peak channel, number of
    events and firing rate. Templates are first computed on the unfiltered
    recording to locate the peak channel (largest peak-to-peak amplitude);
    the SNR is then taken from a template recomputed on the bandpass-filtered
    peak channel only.
    """
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.5k'
    CONTAINER = _CONTAINER
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        """Load recording and sorting, compute the unit table, write JSON."""
        import spikewidgets as sw

        recording = si.MdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=True)
        if (self.channel_ids) and (len(self.channel_ids) > 0):
            recording = si.SubRecordingExtractor(
                parent_recording=recording, channel_ids=self.channel_ids)
        # The full recording is deliberately left unfiltered here; only the
        # single peak channel is bandpass-filtered per unit below.

        sorting = si.MdaSortingExtractor(firings_file=self.firings)
        unit_ids = self.unit_ids
        if (not unit_ids) or (len(unit_ids) == 0):
            unit_ids = sorting.getUnitIds()

        channel_noise_levels = compute_channel_noise_levels(
            recording=recording)

        # No longer use a subset of the recording to compute the templates.
        templates = compute_unit_templates(recording=recording,
                                           sorting=sorting,
                                           unit_ids=unit_ids,
                                           max_num=100)

        recording_channel_ids = recording.getChannelIds()
        ret = []
        for i, unit_id in enumerate(unit_ids):
            template = templates[i]
            max_p2p_amps_on_channels = (np.max(template, axis=1) -
                                        np.min(template, axis=1))
            # Index into the channel axis of the template / noise levels ...
            peak_channel_index = np.argmax(max_p2p_amps_on_channels)
            # ... and the channel id that index corresponds to.
            peak_channel = recording_channel_ids[peak_channel_index]

            # Fix: select the peak channel by its id, not by its position.
            # The two only coincide when the channel ids are 0..M-1.
            R1 = si.SubRecordingExtractor(parent_recording=recording,
                                          channel_ids=[peak_channel])
            R1f = sw.lazyfilters.bandpass_filter(recording=R1,
                                                 freq_min=300,
                                                 freq_max=6000)
            templates2 = compute_unit_templates(recording=R1f,
                                                sorting=sorting,
                                                unit_ids=[unit_id],
                                                max_num=100)
            template2 = templates2[0]

            train = sorting.getUnitSpikeTrain(unit_id=unit_id)
            info0 = dict()
            info0['unit_id'] = int(unit_id)
            # Plain float so the value serializes regardless of numpy dtype.
            info0['snr'] = float(
                np.max(np.abs(template2)) /
                channel_noise_levels[peak_channel_index])
            # Fix: peak_channel already is the channel id; it must not be
            # used again as an index into the channel id list.
            info0['peak_channel'] = int(peak_channel)
            info0['num_events'] = int(len(train))
            info0['firing_rate'] = float(
                len(train) /
                (recording.getNumFrames() / recording.getSamplingFrequency()))
            ret.append(info0)
        write_json_file(self.json_out, ret)
Beispiel #4
0
class KiloSort(mlpr.Processor):
    """Processor wrapper that runs KiloSort through ``kilosort_helper``.

    The recording is read from ``recording_dir`` (optionally restricted to
    ``channels``), sorted inside a randomly named temporary directory, and
    the resulting sorting is written to ``firings_out``.
    """
    NAME = 'KiloSort'
    VERSION = '0.2.0'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS',
        'MKL_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
    ]
    CONTAINER = None
    CONTAINER_SHARE_ID = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording and write the firings; always remove tmpdir."""
        # Ten random uppercase letters keep concurrent runs apart.
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + suffix

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorter_kwargs = dict(
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
            )
            sorting = kilosort_helper(recording=recording,
                                      tmpdir=tmpdir,
                                      **sorter_kwargs)
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then re-raise.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        else:
            shutil.rmtree(tmpdir)
class CreateTimeseriesPlot(mlpr.Processor):
    """Plot a short window of the filtered timeseries and save the figure.

    The plotted window is the 4000 frames that end at the midpoint of the
    recording, limited to the first 20 channels.
    """
    NAME = 'CreateTimeseriesPlot'
    VERSION = '0.1.7'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    jpg_out = mlpr.Output('The plot as a .jpg file')

    def run(self):
        """Create the timeseries plot and pass ``jpg_out`` to ``save_plot``."""
        raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=False)
        if len(self.channels) > 0:
            raw = si.SubRecordingExtractor(parent_recording=raw,
                                           channel_ids=self.channels)
        filtered = sw.lazyfilters.bandpass_filter(
            recording=raw, freq_min=300, freq_max=6000)

        midpoint = int(filtered.getNumFrames() / 2)
        plot_channels = filtered.getChannelIds()
        if len(plot_channels) > 20:
            plot_channels = plot_channels[0:20]

        widget = sw.TimeseriesWidget(
            recording=filtered,
            trange=[midpoint - 4000, midpoint + 0],
            channels=plot_channels,
            width=12,
            height=5,
        )
        widget.plot()
        save_plot(self.jpg_out)
class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting against ground truth and emit the comparison table.

    The table is obtained from ``get_comparison_data_frame`` and written
    twice via ``_write_json_file``: as a dict (``json_out``) and as an HTML
    string (``html_out``).
    """
    VERSION = '0.2.6'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output(
        'Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output(
        'Table as .html file produced from pandas dataframe')
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        """Build the comparison data frame and write both outputs."""
        message = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(message.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        # Restrict the ground truth only when specific units were requested.
        if (self.units_true is not None) and (len(self.units_true) > 0):
            truth = si.SubSortingExtractor(parent_sorting=truth,
                                           unit_ids=self.units_true)

        comparison = SortingComparison(truth, tested, delta_tp=30)
        table = get_comparison_data_frame(comparison=comparison)
        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        table_as_dict = table.transpose().to_dict()
        table_as_html = table.to_html(index=False)
        _write_json_file(table_as_dict, self.json_out)
        # The HTML string is written with the same JSON writer helper.
        _write_json_file(table_as_html, self.html_out)
Beispiel #7
0
class YASS(mlpr.Processor):
    """Processor wrapper that runs YASS through ``yass_helper``.

    The recording is read from ``recording_dir`` (optionally restricted to
    ``channels``), sorted inside a randomly named temporary directory, and
    the resulting sorting is written to ``firings_out``.
    """
    NAME = 'YASS'
    VERSION = '0.1.0'
    # used by container to pass the env variables
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS',
        'MKL_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
    ]
    ADDITIONAL_FILES = ['*.yaml']
    CONTAINER = 'sha1://087767605e10761331699dda29519444bbd823f4/02-12-2019/yass.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')
    #paramfile_out = mlpr.Output('YASS yaml config file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True,
        default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spike width in milliseconds')
    filter = mlpr.BoolParameter(optional=True, default=True)

    def run(self):
        """Sort the recording and write the firings.

        The temporary directory is removed on failure, and on success unless
        the instance has a truthy ``_keep_temp_files`` attribute.
        """
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + suffix

        #num_workers = os.environ.get('NUM_WORKERS', 1)
        #print('num_workers: {}'.format(num_workers))
        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            # The helper also returns the yaml file; it is currently unused.
            sorting, yaml_file = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
            #shutil.copyfile(yaml_file, self.paramfile_out)
        except BaseException:
            # On failure the temporary directory is always removed.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        else:
            keep_files = getattr(self, '_keep_temp_files', False)
            if not keep_files:
                shutil.rmtree(tmpdir)
class PlotUnitWaveforms(mlpr.Processor):
    """Plot the unit waveforms of a sorting and save the figure."""
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        """Create the waveform plot and pass ``plot_out`` to ``save_plot``."""
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=self.channels)
        sorted_units = si.MdaSortingExtractor(firings_file=self.firings)
        widget = sw.UnitWaveformsWidget(recording=rec, sorting=sorted_units)
        widget.plot()
        # The return value of save_plot is not needed here.
        save_plot(self.plot_out)
Beispiel #9
0
class ComputeUnitsInfo(mlpr.Processor):
    """Compute per-unit information via ``compute_units_info``; write JSON."""
    NAME = 'ComputeUnitsInfo'
    VERSION = '0.1.8'
    CONTAINER = _CONTAINER
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channel_ids = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    unit_ids = mlpr.IntegerListParameter(
        description='List of units to use.', optional=True, default=[])
    firings = mlpr.Input(description='Firings file')
    json_out = mlpr.Output('The info as a .json file')

    def run(self):
        """Load recording and sorting, delegate, and write ``json_out``."""
        recording = SFMdaRecordingExtractor(
            dataset_directory=self.recording_dir, download=True)
        sorting = SFMdaSortingExtractor(firings_file=self.firings)
        # Channel and unit selection is handled inside compute_units_info.
        units_info = compute_units_info(
            recording=recording,
            sorting=sorting,
            channel_ids=self.channel_ids,
            unit_ids=self.unit_ids,
        )
        write_json_file(self.json_out, units_info)
class SpykingCircus(mlpr.Processor):
    """Processor wrapper that runs SpyKING CIRCUS via ``sf.sorters``.

    The recording is read from ``recording_dir`` (optionally restricted to
    ``channels``), sorted inside a randomly named temporary directory, and
    the resulting sorting is written to ``firings_out``. The number of cores
    is taken from the ``NUM_WORKERS`` environment variable (default 2).
    """
    NAME = 'SpykingCircus'
    VERSION = '0.1.2'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True,
        default=100,
        description='Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(
        optional=True, default=6, description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=1000,
        description='I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=10000,
        description='I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort the recording and write the firings; always remove tmpdir.

        Raises:
            ValueError: if ``NUM_WORKERS`` is set but is not an integer.
        """
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        # Fix: the directory used to be labelled 'ironclust-tmp-' (copy-paste
        # from the IronClust wrapper), which mislabels leftover directories.
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code

        # Fix: environment values are strings while the fallback is an int;
        # convert so n_cores always receives an integer.
        num_workers = int(os.environ.get('NUM_WORKERS', 2))

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts,
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # Runs on success and on any failure; the exception (if any)
            # continues to propagate after cleanup.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
class IronClust(mlpr.Processor):
    """Processor wrapper that runs IronClust via ``sf.sorters.ironclust``.

    Requires the ``IRONCLUST_SRC`` environment variable. The recording is
    read from ``recording_dir`` (optionally restricted to ``channels``),
    sorted inside a randomly named temporary directory, and the resulting
    sorting is written to ``firings_out``.
    """
    NAME = 'IronClust'
    VERSION = '4.2.6'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    prm_template_name = mlpr.StringParameter(
        optional=False, description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98, description='TODO')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=3, description='TODO')

    def run(self):
        """Sort the recording and write the firings; always remove tmpdir.

        Raises:
            Exception: if ``IRONCLUST_SRC`` is unset or empty.
        """
        source_dir = os.environ.get('IRONCLUST_SRC', None)
        if not source_dir:
            raise Exception('Environment variable not set: IRONCLUST_SRC')

        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + suffix

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sf.sorters.ironclust(
                recording=recording,
                tmpdir=tmpdir,  ## TODO
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_src=source_dir,
            )
            si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then re-raise.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        else:
            shutil.rmtree(tmpdir)
Beispiel #12
0
class RepeatText(mlpr.Processor):
    """Write the contents of ``textfile`` repeated ``num_repeats`` times."""
    textfile = mlpr.Input(help="input text file")
    textfile_out = mlpr.Output(help="output text file")
    # Fix: this was declared as an integer *list* parameter, but run() uses
    # it as a single non-negative integer (comparison with 0, repeat count).
    num_repeats = mlpr.IntegerParameter(
        help="Number of times to repeat the text")

    def run(self):
        """Read the input text and write it back ``num_repeats`` times.

        Raises:
            AssertionError: if ``num_repeats`` is negative.
        """
        assert self.num_repeats >= 0
        with open(self.textfile, 'r') as f:
            txt = f.read()
        # String repetition replaces the quadratic concatenation loop;
        # zero repeats yields an empty output file.
        repeated = txt * self.num_repeats
        with open(self.textfile_out, 'w') as f:
            f.write(repeated)
class PlotAutoCorrelograms(mlpr.Processor):
    """Plot correlograms for a sorting and save the figure.

    The plot is produced with ``sw.CrossCorrelogramsWidget`` using the
    sampling frequency of the recording.
    """
    NAME = 'spikeforest.PlotAutoCorrelograms'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings = mlpr.Input('Firings file (sorting)')
    plot_out = mlpr.Output('Plot as .jpg image file')

    def run(self):
        """Create the plot and pass ``plot_out`` to ``save_plot``."""
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=False)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=self.channels)
        sorted_units = si.MdaSortingExtractor(firings_file=self.firings)
        # The recording is only needed for its sampling frequency.
        widget = sw.CrossCorrelogramsWidget(
            samplerate=rec.getSamplingFrequency(), sorting=sorted_units)
        widget.plot()
        save_plot(self.plot_out)
class ComputeRecordingInfo(mlpr.Processor):
    """Write basic facts about a recording to a JSON file.

    The output holds ``samplerate``, ``num_channels`` and ``duration_sec``
    (number of frames divided by the sample rate).
    """
    NAME = 'ComputeRecordingInfo'
    VERSION = '0.1.0'
    recording_dir = mlpr.Input(directory=True,
                               description='Recording directory')
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    json_out = mlpr.Output('Info in .json file')

    def run(self):
        """Collect the recording info and write it to ``json_out``."""
        rec = si.MdaRecordingExtractor(dataset_directory=self.recording_dir,
                                       download=False)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=self.channels)
        samplerate = rec.getSamplingFrequency()
        info = {
            'samplerate': samplerate,
            'num_channels': len(rec.getChannelIds()),
            'duration_sec': rec.getNumFrames() / samplerate,
        }
        write_json_file(self.json_out, info)
class GenSortingComparisonTableNew(mlpr.Processor):
    """Compare a sorting against ground truth and emit the comparison table.

    Uses ``st.comparison.compare_sorter_to_ground_truth``; the counts and
    performance tables are joined, columns are renamed, and the result is
    written via ``_write_json_file`` as a dict (``json_out``) and as an HTML
    string (``html_out``).
    """
    VERSION = '0.3.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output(
        'Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output(
        'Table as .html file produced from pandas dataframe')
    # CONTAINER = 'sha1://5627c39b9bd729fc011cbfce6e8a7c37f8bcbc6b/spikeforest_basic.simg'
    # CONTAINER = 'sha1://0944f052e22de0f186bb6c5cb2814a71f118f2d1/spikeforest_basic.simg'  # MAY26JJJ
    CONTAINER = 'sha1://4904b8f914eb159618b6579fb9ba07b269bb2c61/06-26-2019/spikeforest_basic.simg'

    def run(self):
        """Run the comparison, shape the data frame, write both outputs."""
        message = 'GenSortingComparisonTable: firings={}, firings_true={}, units_true={}'
        print(message.format(self.firings, self.firings_true, self.units_true))

        tested = SFMdaSortingExtractor(firings_file=self.firings)
        truth = SFMdaSortingExtractor(firings_file=self.firings_true)
        # Restrict the ground truth only when specific units were requested.
        if (self.units_true is not None) and (len(self.units_true) > 0):
            truth = si.SubSortingExtractor(parent_sorting=truth,
                                           unit_ids=self.units_true)

        comparison = st.comparison.compare_sorter_to_ground_truth(
            gt_sorting=truth,
            tested_sorting=tested,
            delta_time=0.3,
            min_accuracy=0,
            compute_misclassification=False,
            exhaustive_gt=False  # Fix this in future
        )
        # Join counts and performance column-wise; the index becomes a column.
        table = pd.concat([comparison.count, comparison.get_performance()],
                          axis=1).reset_index()

        column_names = {
            'gt_unit_id': 'unit_id',
            'fp': 'num_false_positives',
            'fn': 'num_false_negatives',
            'tested_id': 'best_unit',
            'tp': 'num_matches',
        }
        table = table.rename(columns=column_names)
        table['matched_unit'] = table['best_unit']
        table['f_p'] = 1 - table['precision']
        table['f_n'] = 1 - table['recall']

        # sw.SortingComparisonTable(comparison=SC).getDataframe()
        table_as_dict = table.transpose().to_dict()
        table_as_html = table.to_html(index=False)
        _write_json_file(table_as_dict, self.json_out)
        # The HTML string is written with the same JSON writer helper.
        _write_json_file(table_as_html, self.html_out)
Beispiel #16
0
class GenSortingComparisonTable(mlpr.Processor):
    """Compare a sorting against ground truth and emit the comparison table.

    The table comes from ``sw.SortingComparisonTable`` and is written via
    ``_write_json_file`` as a dict (``json_out``) and as an HTML string
    (``html_out``).
    """
    VERSION = '0.1.1'
    firings = mlpr.Input('Firings file (sorting)')
    firings_true = mlpr.Input('True firings file')
    units_true = mlpr.IntegerListParameter('List of true units to consider')
    json_out = mlpr.Output(
        'Table as .json file produced from pandas dataframe')
    html_out = mlpr.Output(
        'Table as .html file produced from pandas dataframe')

    def run(self):
        """Build the comparison data frame and write both outputs."""
        tested = si.MdaSortingExtractor(firings_file=self.firings)
        truth = si.MdaSortingExtractor(firings_file=self.firings_true)
        if len(self.units_true) > 0:
            truth = si.SubSortingExtractor(parent_sorting=truth,
                                           unit_ids=self.units_true)
        comparison = st.comparison.SortingComparison(truth, tested)
        table = sw.SortingComparisonTable(comparison=comparison).getDataframe()
        table_as_dict = table.transpose().to_dict()
        table_as_html = table.to_html(index=False)
        _write_json_file(table_as_dict, self.json_out)
        # The HTML string is written with the same JSON writer helper.
        _write_json_file(table_as_html, self.html_out)
class MountainSort4(mlpr.Processor):
    """Processor wrapper that runs MountainSort4 via ``sf.sorters``.

    The recording is read from ``recording_dir`` (optionally restricted to
    ``channels``) and the resulting sorting is written to ``firings_out``.
    """
    NAME = 'MountainSort4'
    VERSION = '4.0.1'

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        'Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        'Use -1 to include all channels in every neighborhood')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=6000,
        description='Use 0 for no bandpass filtering')
    whiten = mlpr.BoolParameter(
        optional=True, default=True,
        description='Whether to do channel whitening as part of preprocessing')
    clip_size = mlpr.IntegerParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=3, description='')
    detect_interval = mlpr.IntegerParameter(
        optional=True, default=10,
        description='Minimum number of timepoints between events detected on the same channel')
    noise_overlap_threshold = mlpr.FloatParameter(
        optional=True, default=0.15,
        description='Use None for no automated curation')

    def run(self):
        """Sort the recording and write the firings to ``firings_out``."""
        rec = si.MdaRecordingExtractor(self.recording_dir)
        if len(self.channels) > 0:
            rec = si.SubRecordingExtractor(parent_recording=rec,
                                           channel_ids=self.channels)

        # NUM_WORKERS is converted to int only when set to a non-empty value;
        # otherwise the raw value (None or '') is passed through unchanged.
        workers_env = os.environ.get('NUM_WORKERS', None)
        num_workers = int(workers_env) if workers_env else workers_env

        sorter_kwargs = dict(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            whiten=self.whiten,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            noise_overlap_threshold=self.noise_overlap_threshold,
            num_workers=num_workers,
        )
        sorting = sf.sorters.mountainsort4(recording=rec, **sorter_kwargs)
        si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                            save_path=self.firings_out)
Beispiel #18
0
class JRClust(mlpr.Processor):
    """
    Wrapper for JRClust v4.

    Written by J. James Jun, May 8, 2019; modified from
    `spikeforest/spikesorters/ironclust/ironclust.py`.

    Installation in the SpikeForest environment:
    1. Run `git clone https://github.com/JaneliaSciComp/JRCLUST`
    2. Activate the conda environment for SpikeForest
    3. Create `JRCLUST_PATH` and `MDAIO_PATH`
    4. Flatiron execution: `module load matlab/R2019a cuda/10.0.130_410.48`
    DO NOT LOAD `module load gcc`.
    DO NOT USE `matlab/R2018b` (it will crash while creating a parpool).

    Reference: James Jun, et al. Real-time spike sorting platform for
    high-density extracellular probes with ground-truth validation and
    drift correction.
    https://github.com/JaneliaSciComp/JRCLUST
    """

    NAME = 'JRClust'
    VERSION = '0.1.3'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS',
        'MKL_NUM_THREADS',
        'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS',
        'TEMPDIR',
    ]
    ADDITIONAL_FILES = ['*.m', '*.prm']
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50, description='')
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4.5, description='detection threshold')
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True, default=3000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging')
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=1,
        description='Number of principal components per channel')

    # added in version 0.2.4
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass',
        description='{none, bandpass, wiener, fftdiff, ndiff}')
    nDiffOrder = mlpr.FloatParameter(optional=True, default=2, description='')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, mean, median}')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.IntegerParameter(
        optional=True, default=1, description='Use GPU if available')
    fParfor = mlpr.IntegerParameter(
        optional=True, default=0, description='Use parfor if available')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca',
        description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')

    def run(self):
        """Sort the recording with ``jrclust_helper`` and write the firings.

        The temporary directory is removed afterwards (on success and on
        failure) unless the instance has a truthy ``_keep_temp_files``.
        """
        tmpdir = _get_tmpdir('jrclust')

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            dataset_params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Forward every entry of self.PARAMETERS to the helper by name.
            sorter_params = {
                spec.name: getattr(self, spec.name)
                for spec in self.PARAMETERS
            }

            sorting = jrclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=dataset_params,
                **sorter_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            # Clean up on any failure (including interrupts), then re-raise.
            if os.path.exists(tmpdir) and not getattr(
                    self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
            raise
        else:
            if not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
Beispiel #19
0
class TridesclousOld(mlpr.Processor):
    """
    SpikeForest processor wrapping the tridesclous spike sorter.

    tridesclous is one of the more convenient, fast and elegant
    spike sorters.

    Installation instruction
        >>> pip install https://github.com/tridesclous/tridesclous/archive/master.zip

    More information on tridesclous at:
    * https://github.com/tridesclous/tridesclous
    * https://tridesclous.readthedocs.io

    """

    NAME = 'Tridesclous'
    VERSION = '0.1.1'  # wrapper VERSION
    ADDITIONAL_FILES = []
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS', 'TEMPDIR'
    ]
    CONTAINER = 'sha1://9fb4a9350492ee84c8ea5d8692434ecba3cf33da/2019-05-13/tridesclous.simg'
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    channels = mlpr.IntegerListParameter(
        optional=True, default=[], description='List of channels to use.')
    # NOTE(review): detect_sign is declared but run() hard-codes
    # peak_sign='-'; confirm whether it should drive the peak detector.
    detect_sign = mlpr.FloatParameter(optional=True,
                                      default=-1,
                                      description='')
    detection_threshold = mlpr.FloatParameter(optional=True,
                                              default=5.5,
                                              description='')
    freq_min = mlpr.FloatParameter(optional=True, default=400, description='')
    freq_max = mlpr.FloatParameter(optional=True, default=5000, description='')
    waveforms_n_left = mlpr.IntegerParameter(description='',
                                             optional=True,
                                             default=-45)
    waveforms_n_right = mlpr.IntegerParameter(description='',
                                              optional=True,
                                              default=60)
    align_waveform = mlpr.BoolParameter(description='',
                                        optional=True,
                                        default=False)
    common_ref_removal = mlpr.BoolParameter(description='',
                                            optional=True,
                                            default=False)
    peak_span = mlpr.FloatParameter(optional=True,
                                    default=.0002,
                                    description='')
    alien_value_threshold = mlpr.FloatParameter(optional=True,
                                                default=100,
                                                description='')

    def run(self):
        """Sort the recording with tridesclous and write the firings file.

        Loads the recording from ``self.recording_dir`` (restricted to
        ``self.channels`` when that list is non-empty), writes a probe file
        and, unless the recording is already a frame-first binary file, a raw
        float32 copy into a scratch directory, registers both with a
        ``tdc.DataIO`` and delegates the sorting to ``tdc_helper``.  The
        result is written to ``self.firings_out``.

        The scratch directory is removed afterwards, on success as well as on
        failure, unless ``self._keep_temp_files`` is truthy.  Exceptions
        propagate unchanged.
        """
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))

        try:
            # Harmless if _get_tmpdir already created the directory.
            tmpdir.mkdir(parents=True, exist_ok=True)

            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)

            params = {
                'fullchain_kargs': {
                    'duration': 300.,
                    'preprocessor': {
                        'highpass_freq': self.freq_min,
                        'lowpass_freq': self.freq_max,
                        'smooth_size': 0,
                        'chunksize': 1024,
                        'lostfront_chunksize': 128,
                        'signalpreprocessor_engine': 'numpy',
                        'common_ref_removal': self.common_ref_removal,
                    },
                    'peak_detector': {
                        'peakdetector_engine': 'numpy',
                        'peak_sign': '-',
                        'relative_threshold': self.detection_threshold,
                        'peak_span': self.peak_span,
                    },
                    'noise_snippet': {
                        'nb_snippet': 300,
                    },
                    'extract_waveforms': {
                        'n_left': self.waveforms_n_left,
                        'n_right': self.waveforms_n_right,
                        'mode': 'rand',
                        'nb_max': 20000,
                        'align_waveform': self.align_waveform,
                    },
                    'clean_waveforms': {
                        'alien_value_threshold': self.alien_value_threshold,
                    },
                },
                'feat_method': 'peak_max',
                'feat_kargs': {},
                'clust_method': 'sawchaincut',
                'clust_kargs': {
                    'kde_bandwith': 1.
                },
            }

            # save prb file:
            probe_file = tmpdir / 'probe.prb'
            se.save_probe_file(recording, probe_file, format='spyking_circus')

            nb_chan = recording.get_num_channels()

            # source file
            if isinstance(
                    recording,
                    se.BinDatRecordingExtractor) and recording._frame_first:
                # no need to copy: reuse the existing binary file in place
                raw_filename = recording._datfile
                dtype = recording._timeseries.dtype.str
                offset = recording._timeseries.offset
            else:
                # save binary file (chunk by chunk) into a new file
                raw_filename = tmpdir / 'raw_signals.raw'
                chunksize = 2**24 // nb_chan
                se.write_binary_dat_format(recording,
                                           raw_filename,
                                           time_axis=0,
                                           dtype='float32',
                                           chunksize=chunksize)
                dtype = 'float32'
                offset = 0

            # initialize source and probe file
            tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
            tdc_dataio.set_data_source(
                type='RawData',
                filenames=[str(raw_filename)],
                dtype=dtype,
                sample_rate=recording.get_sampling_frequency(),
                total_channel=nb_chan,
                offset=offset)
            tdc_dataio.set_probe_file(str(probe_file))

            sorting = tdc_helper(tmpdir=tmpdir,
                                 params=params,
                                 recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # Single cleanup point covering the setup steps as well, so a
            # failure while writing the probe or raw file no longer leaks
            # the scratch directory.
            if (not getattr(self, '_keep_temp_files', False)
                    and tmpdir.exists()):
                shutil.rmtree(tmpdir)
Beispiel #20
0
class KiloSort2Old(mlpr.Processor):
    """
    KiloSort2 wrapper for SpikeForest framework
      written by J. James Jun, May 7, 2019
      modified from `spiketoolkit/sorters/Kilosort`
      to be made compatible with SpikeForest

    [Prerequisite]
    1. MATLAB (Tested on R2018b)
    2. CUDA Toolkit v9.1

    [Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/alexmorley/Kilosort2.git`
      Kilosort2 currently doesn't work on tetrodes and low-channel count probes (as of May 7, 2019).
      Clone from Alex Morley's repository that fixed these issues.
      Original Kilosort2 code can be obtained from `https://github.com/MouseLand/Kilosort2.git`
    2. (optional) If Alex Morley's latest version doesn't work with SpikeForest, run
        `git checkout 43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956`
    3. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    4. Add `KILOSORT2_PATH=...` in your .bashrc file.
    """

    NAME = 'KiloSort2'
    VERSION = '0.3.3'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    CONTAINER = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        default=-1,
        optional=True,
        description=
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(
        default=30,
        optional=True,
        description='Use -1 to include all channels in every neighborhood')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=6,
                                           description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=150,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True,
                                       default=0.98,
                                       description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')
    minFR = mlpr.FloatParameter(
        default=1 / 50,
        optional=True,
        description=
        'minimum spike rate (Hz), if a cluster falls below this for too long it gets removed'
    )

    @staticmethod
    def install():
        """Install the pinned Kilosort2 commit via ``install_kilosort2``.

        Returns whatever ``install_kilosort2`` returns.
        """
        print('Auto-installing kilosort.')
        return install_kilosort2(
            repo='https://github.com/alexmorley/Kilosort2',
            commit='43cbbfff89b9c88cdeb147ffd4ac35bfde9c7956')

    def run(self):
        """Sort the recording with KiloSort2 and write the firings file.

        Loads the recording from ``self.recording_dir`` (restricted to
        ``self.channels`` when that list is non-empty), runs
        ``kilosort2_helper`` in a randomly named scratch directory under
        ``$TEMPDIR`` (default ``/tmp``) and writes the result to
        ``self.firings_out``.

        The scratch directory is removed afterwards, on success as well as on
        failure, unless ``self._keep_temp_files`` is truthy.  Exceptions
        propagate unchanged.
        """
        code = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            os.makedirs(tmpdir, exist_ok=True)
            sorting = kilosort2_helper(recording=recording,
                                       tmpdir=tmpdir,
                                       detect_sign=self.detect_sign,
                                       adjacency_radius=self.adjacency_radius,
                                       detect_threshold=self.detect_threshold,
                                       merge_thresh=self.merge_thresh,
                                       freq_min=self.freq_min,
                                       freq_max=self.freq_max,
                                       pc_per_chan=self.pc_per_chan,
                                       minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # The keep flag is read from the instance rather than a
            # hard-coded local, so the scratch directory is actually cleaned
            # up unless the caller asked to keep it.
            if (not getattr(self, '_keep_temp_files', False)
                    and os.path.exists(tmpdir)):
                shutil.rmtree(tmpdir)
Beispiel #21
0
class IronClust(mlpr.Processor):
    """SpikeForest processor that runs IronClust through ``spikesorters``.

    ``run`` installs a pinned IronClust commit, configures
    ``spikesorters.IronClustSorter`` from the parameters declared below,
    and writes the resulting sorting to ``firings_out``.
    """

    NAME = 'IronClust'
    VERSION = '0.7.9'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS', 'TEMPDIR']
    CONTAINER: Union[str, None] = None

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    firings_out = mlpr.Output('Output firings file')
    # NOTE(review): `channels` is declared but run() never applies it to the
    # recording -- confirm whether channel sub-selection is intended here.
    channels = mlpr.IntegerListParameter(
        optional=True, default=[],
        description='List of channels to use.'
    )
    detect_sign = mlpr.IntegerParameter(
        optional=True, default=-1,
        description='Use -1, 0, or 1, depending on the sign of the spikes in the recording'
    )
    adjacency_radius = mlpr.FloatParameter(
        optional=True, default=50,
        description='Use -1 to include all channels in every neighborhood'
    )
    adjacency_radius_out = mlpr.FloatParameter(
        optional=True, default=100,
        description='Use -1 to include all channels in every neighborhood'
    )
    detect_threshold = mlpr.FloatParameter(
        optional=True, default=4,
        description='detection threshold'
    )
    prm_template_name = mlpr.StringParameter(
        optional=True, default='',
        description='.prm template file name'
    )
    freq_min = mlpr.FloatParameter(
        optional=True, default=300,
        description='Use 0 for no bandpass filtering'
    )
    freq_max = mlpr.FloatParameter(
        optional=True, default=8000,
        description='Use 0 for no bandpass filtering'
    )
    merge_thresh = mlpr.FloatParameter(
        optional=True, default=0.98,
        description='Threshold for automated merging'
    )
    pc_per_chan = mlpr.IntegerParameter(
        optional=True, default=0,
        description='Number of principal components per channel'
    )

    # added in version 0.2.4
    whiten = mlpr.BoolParameter(
        optional=True, default=False, description='Whether to do channel whitening as part of preprocessing')
    filter_type = mlpr.StringParameter(
        optional=True, default='bandpass', description='{none, bandpass, wiener, fftdiff, ndiff}')
    filter_detect_type = mlpr.StringParameter(
        optional=True, default='none', description='{none, bandpass, wiener, fftdiff, ndiff}')
    common_ref_type = mlpr.StringParameter(
        optional=True, default='mean', description='{none, mean, median}')
    batch_sec_drift = mlpr.FloatParameter(
        optional=True, default=300, description='batch duration in seconds. clustering time duration')
    step_sec_drift = mlpr.FloatParameter(
        optional=True, default=20, description='compute anatomical similarity every n sec')
    knn = mlpr.IntegerParameter(
        optional=True, default=30, description='K nearest neighbors')
    min_count = mlpr.IntegerParameter(
        optional=True, default=30, description='Minimum cluster size')
    fGpu = mlpr.BoolParameter(
        optional=True, default=True, description='Use GPU if available')
    fft_thresh = mlpr.FloatParameter(
        optional=True, default=8, description='FFT-based noise peak threshold')
    fft_thresh_low = mlpr.FloatParameter(
        optional=True, default=0, description='FFT-based noise peak lower threshold (set to 0 to disable dual thresholding scheme)')
    nSites_whiten = mlpr.IntegerParameter(
        optional=True, default=32, description='Number of adjacent channels to whiten')
    feature_type = mlpr.StringParameter(
        optional=True, default='gpca', description='{gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}')
    delta_cut = mlpr.FloatParameter(
        optional=True, default=1, description='Cluster detection threshold (delta-cutoff)')
    post_merge_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='post_merge_mode')
    sort_mode = mlpr.IntegerParameter(
        optional=True, default=1, description='sort_mode')

    @staticmethod
    def install():
        """Install the pinned IronClust commit via ``install_ironclust``.

        Returns whatever ``install_ironclust`` returns; ``run`` passes that
        value to ``IronClustSorter.set_ironclust_path``.
        """
        print('Auto-installing ironclust.')
        return install_ironclust(commit='72bb6d097e0875d6cfe52bddf5f782e667e1b042')

    def run(self):
        """Sort the recording with IronClust and write the firings file.

        Installs IronClust, runs ``spikesorters.IronClustSorter`` on the
        recording in ``self.recording_dir`` using a randomly named output
        folder under ``$TEMPDIR`` (default ``/tmp``), prints the value
        returned by ``sorter.run()`` in the ``#SF-SORTER-RUNTIME#`` marker
        line, and writes the sorting to ``self.firings_out``.

        The output folder is removed afterwards, on success as well as on
        failure, unless ``self._keep_temp_files`` is truthy.

        Raises:
            Exception: if the IronClust installation fails; the original
                error is attached as the cause.
        """
        import shutil
        import spikesorters as sorters
        print('IronClust......')

        try:
            ironclust_path = IronClust.install()
        except Exception as err:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit are
            # not rewritten into a generic installation failure.
            traceback.print_exc()
            raise Exception('Problem installing ironclust.') from err
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(random.choice(string.ascii_uppercase)
                       for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        try:
            sorter = sorters.IronClustSorter(
                recording=recording,
                output_folder=tmpdir,
                debug=True,
                # The sorter must not delete the folder itself; removal is
                # handled below according to _keep_temp_files.
                delete_output_folder=False
            )

            sorter.set_params(
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                adjacency_radius_out=self.adjacency_radius_out,
                detect_threshold=self.detect_threshold,
                prm_template_name=self.prm_template_name,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                merge_thresh=self.merge_thresh,
                pc_per_chan=self.pc_per_chan,
                whiten=self.whiten,
                filter_type=self.filter_type,
                filter_detect_type=self.filter_detect_type,
                common_ref_type=self.common_ref_type,
                batch_sec_drift=self.batch_sec_drift,
                step_sec_drift=self.step_sec_drift,
                knn=self.knn,
                min_count=self.min_count,
                fGpu=self.fGpu,
                fft_thresh=self.fft_thresh,
                fft_thresh_low=self.fft_thresh_low,
                nSites_whiten=self.nSites_whiten,
                feature_type=self.feature_type,
                delta_cut=self.delta_cut,
                post_merge_mode=self.post_merge_mode,
                sort_mode=self.sort_mode
            )
            # The value returned by sorter.run() is reported as the runtime.
            runtime = sorter.run()
            print('#SF-SORTER-RUNTIME#{:.3f}#'.format(runtime))
            sorting = sorter.get_result()

            SFMdaSortingExtractor.write_sorting(
                sorting=sorting, save_path=self.firings_out)
        finally:
            if (not getattr(self, '_keep_temp_files', False)
                    and os.path.exists(tmpdir)):
                shutil.rmtree(tmpdir)
Beispiel #22
0
class SpykingCircus(mlpr.Processor):
    """SpikeForest processor wrapping the SpyKING CIRCUS spike sorter.

    ``run`` loads an MDA recording, calls ``spyking_circus`` in a scratch
    directory, and writes the resulting sorting to ``firings_out``.
    """

    NAME = 'SpykingCircus'
    VERSION = '0.2.2'
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    ADDITIONAL_FILES = ['*.params']
    CONTAINER = 'sha1://914becce45aec56a84dd1dd4bca4037b09c50373/02-12-2019/spyking_circus.simg'
    CONTAINER_SHARE_ID = '69432e9201d0'  # place to look for container

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    # Declared without optional/default, unlike the parameters below.
    detect_sign = mlpr.IntegerParameter(description='-1, 1, or 0')
    adjacency_radius = mlpr.FloatParameter(
        optional=True,
        default=100,
        description=
        'Channel neighborhood adjacency radius corresponding to geom file')
    spike_thresh = mlpr.FloatParameter(optional=True,
                                       default=6,
                                       description='Threshold for detection')
    template_width_ms = mlpr.FloatParameter(
        optional=True, default=3, description='Spyking circus parameter')
    filter = mlpr.BoolParameter(optional=True, default=True)
    whitening_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=1000,
        description=
        'I believe it relates to subsampling and affects compute time')
    clustering_max_elts = mlpr.IntegerParameter(
        optional=True,
        default=10000,
        description=
        'I believe it relates to subsampling and affects compute time')

    def run(self):
        """Sort the recording with SpyKING CIRCUS and write the firings file.

        Loads the recording from ``self.recording_dir`` (restricted to
        ``self.channels`` when that list is non-empty), runs
        ``spyking_circus`` in a randomly named scratch directory under
        ``$TEMPDIR`` (default ``/tmp``) with ``n_cores`` taken from
        ``$NUM_WORKERS`` (default 1), and writes the result to
        ``self.firings_out``.

        The scratch directory is removed afterwards, on success as well as on
        failure, unless ``self._keep_temp_files`` is truthy.

        Raises:
            ValueError: if ``$NUM_WORKERS`` is set to a non-integer value.
        """
        code = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR',
                                '/tmp') + '/spyking-circus-tmp-' + code

        # Environment values are strings; normalise so n_cores is always an
        # int, matching the default.  An unset or empty variable means 1.
        num_workers = int(os.environ.get('NUM_WORKERS') or 1)

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            os.makedirs(tmpdir, exist_ok=True)
            sorting = spyking_circus(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                spike_thresh=self.spike_thresh,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
                merge_spikes=True,
                n_cores=num_workers,
                electrode_dimensions=None,
                whitening_max_elts=self.whitening_max_elts,
                clustering_max_elts=self.clustering_max_elts,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # Single cleanup point for both outcomes; the existence check
            # keeps a successful sort from failing on a missing directory.
            if (not getattr(self, '_keep_temp_files', False)
                    and os.path.exists(tmpdir)):
                shutil.rmtree(tmpdir)
Beispiel #23
0
class KiloSortOld(mlpr.Processor):
    """
    Kilosort wrapper for SpikeForest framework
      written by J. James Jun, May 21, 2019

    [Prerequisite]
    1. MATLAB (Tested on R2019a)
    2. CUDA Toolkit v10.0

    [Optional: Installation instruction in SpikeForest environment]
    1. Run `git clone https://github.com/cortex-lab/KiloSort.git`
    2. In Matlab, run `CUDA/mexGPUall` to compile all CUDA codes
    3. Add `KILOSORT_PATH_DEV=...` in your .bashrc file.
    """

    NAME = 'KiloSort'
    VERSION = '0.2.4'  # wrapper VERSION
    ADDITIONAL_FILES = ['*.m']
    ENVIRONMENT_VARIABLES = [
        'NUM_WORKERS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS',
        'OMP_NUM_THREADS'
    ]
    CONTAINER = None
    LOCAL_MODULES = ['../../spikeforest']

    recording_dir = mlpr.Input('Directory of recording', directory=True)
    channels = mlpr.IntegerListParameter(
        description='List of channels to use.', optional=True, default=[])
    firings_out = mlpr.Output('Output firings file')

    detect_sign = mlpr.IntegerParameter(
        optional=True,
        default=-1,
        description=
        'Use -1 or 1, depending on the sign of the spikes in the recording')
    adjacency_radius = mlpr.FloatParameter(optional=True,
                                           default=-1,
                                           description='Currently unused')
    detect_threshold = mlpr.FloatParameter(optional=True,
                                           default=3,
                                           description='')
    # prm_template_name=mlpr.StringParameter(optional=False,description='TODO')
    freq_min = mlpr.FloatParameter(
        optional=True,
        default=300,
        description='Use 0 for no bandpass filtering')
    freq_max = mlpr.FloatParameter(
        optional=True,
        default=6000,
        description='Use 0 for no bandpass filtering')
    merge_thresh = mlpr.FloatParameter(optional=True,
                                       default=0.98,
                                       description='TODO')
    pc_per_chan = mlpr.IntegerParameter(optional=True,
                                        default=3,
                                        description='TODO')

    @staticmethod
    def install():
        """Install the pinned KiloSort commit via ``install_kilosort``.

        Returns whatever ``install_kilosort`` returns.
        """
        print('Auto-installing kilosort.')
        return install_kilosort(
            repo='https://github.com/cortex-lab/KiloSort.git',
            commit='3f33771f8fdf8c3846a7f8a75cc8c318b44ed48c')

    def run(self):
        """Sort the recording with KiloSort and write the firings file.

        Loads the recording from ``self.recording_dir`` (restricted to
        ``self.channels`` when that list is non-empty), runs
        ``kilosort_helper`` in a randomly named scratch directory under
        ``$TEMPDIR`` (default ``/tmp``) and writes the result to
        ``self.firings_out``.

        The scratch directory is removed afterwards, on success as well as on
        failure, unless ``self._keep_temp_files`` is truthy.  Exceptions
        propagate unchanged.
        """
        code = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            os.makedirs(tmpdir, exist_ok=True)
            sorting = kilosort_helper(recording=recording,
                                      tmpdir=tmpdir,
                                      detect_sign=self.detect_sign,
                                      adjacency_radius=self.adjacency_radius,
                                      detect_threshold=self.detect_threshold,
                                      merge_thresh=self.merge_thresh,
                                      freq_min=self.freq_min,
                                      freq_max=self.freq_max,
                                      pc_per_chan=self.pc_per_chan)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # One flag for both outcomes: a failed run keeps its scratch
            # files when the caller asked for them, which is when they are
            # most useful for debugging.
            if (not getattr(self, '_keep_temp_files', False)
                    and os.path.exists(tmpdir)):
                shutil.rmtree(tmpdir)