예제 #1
0
 def run(self):
     """Run spyking-circus on ``self.recording_dir`` and write the firings.

     The recording is optionally restricted to ``self.channels``. Sorter
     output goes to a uniquely named scratch directory under ``$TEMPDIR``
     (default ``/tmp``), which is removed afterwards whether or not sorting
     succeeds. The resulting sorting is written in MDA format to
     ``self.firings_out``. ``$NUM_WORKERS`` (default 2) is passed as
     ``n_cores``.
     """
     code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
     # Name the scratch dir after the sorter actually run here.
     tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code

     # os.environ values are strings; convert so n_cores always receives an
     # int, whether or not NUM_WORKERS is set.
     num_workers_env = os.environ.get('NUM_WORKERS')
     num_workers = int(num_workers_env) if num_workers_env else 2

     try:
         recording = si.MdaRecordingExtractor(self.recording_dir)
         if len(self.channels) > 0:
             recording = si.SubRecordingExtractor(
                 parent_recording=recording, channel_ids=self.channels)
         if not os.path.exists(tmpdir):
             os.mkdir(tmpdir)
         sorting = sf.sorters.spyking_circus(
             recording=recording,
             output_folder=tmpdir,
             probe_file=None,
             file_name=None,
             detect_sign=self.detect_sign,
             adjacency_radius=self.adjacency_radius,
             spike_thresh=self.spike_thresh,
             template_width_ms=self.template_width_ms,
             filter=self.filter,
             merge_spikes=True,
             n_cores=num_workers,
             electrode_dimensions=None,
             whitening_max_elts=self.whitening_max_elts,
             clustering_max_elts=self.clustering_max_elts,
         )
         si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
     finally:
         # Single cleanup path for both success and failure.
         if os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)
예제 #2
0
    def run(self):
        """Run YASS on ``self.recording_dir`` and write the firings to ``self.firings_out``.

        The recording is optionally restricted to ``self.channels``. Sorter
        output goes to a uniquely named scratch directory under ``$TEMPDIR``
        (default ``/tmp``). The directory is removed when the run ends unless
        ``self._keep_temp_files`` is truthy. In that case it is kept on both
        success and failure, so the files of a failed run remain available
        for debugging.
        """
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code
        keep_temp_files = getattr(self, '_keep_temp_files', False)

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            # The second return value (a YAML file, per its name) is not used.
            # TODO: copy it to a param-file output if one is added.
            sorting, _yaml_file = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            se.MdaSortingExtractor.writeSorting(
                sorting=sorting, save_path=self.firings_out)
        finally:
            # Honour _keep_temp_files on both paths.
            if not keep_temp_files and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
예제 #3
0
def test_remove_bad_channels():
    """Exercise remove_bad_channels with explicit ids and with threshold detection.

    Both phases write scratch data to ./test. It is removed in ``finally`` so
    a failing assertion does not leave a stale folder behind for the next run.
    """
    try:
        rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                    num_channels=4, seed=0)
        rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
        assert 0 not in rec_rm.get_channel_ids()

        rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
        assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()

        check_dumping(rec_rm)
    finally:
        shutil.rmtree('test', ignore_errors=True)

    # Channel 1 gets 10x the amplitude of the others, so bad_threshold=2
    # should flag it. A seeded generator keeps the data reproducible.
    timeseries = np.random.RandomState(0).randn(4, 60000)
    timeseries[1] = 10 * timeseries[1]

    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    try:
        se.MdaRecordingExtractor.write_recording(rec_np, 'test')
        rec = se.MdaRecordingExtractor('test')
        # First without `seconds`, then with seconds=0.1 and seconds=10.
        for extra_kwargs in ({}, {'seconds': 0.1}, {'seconds': 10}):
            rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, **extra_kwargs)
            assert 1 not in rec_rm.get_channel_ids(), extra_kwargs
            check_dumping(rec_rm)
    finally:
        shutil.rmtree('test', ignore_errors=True)
예제 #4
0
 def run(self):
     """Run IronClust on ``self.recording_dir`` and write the firings to ``self.firings_out``.

     Requires ``$IRONCLUST_SRC``. The recording is optionally restricted to
     ``self.channels``. Scratch files go to a uniquely named directory under
     ``$TEMPDIR`` (default ``/tmp``), which is removed afterwards whether or
     not sorting succeeds.

     Raises
     ------
     Exception
         If ``IRONCLUST_SRC`` is not set.
     """
     ironclust_src = os.environ.get('IRONCLUST_SRC', None)
     if not ironclust_src:
         raise Exception('Environment variable not set: IRONCLUST_SRC')
     code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
     tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

     try:
         recording = si.MdaRecordingExtractor(self.recording_dir)
         if len(self.channels) > 0:
             recording = si.SubRecordingExtractor(
                 parent_recording=recording, channel_ids=self.channels)
         if not os.path.exists(tmpdir):
             os.mkdir(tmpdir)
         sorting = sf.sorters.ironclust(
             recording=recording,
             tmpdir=tmpdir,  # TODO
             detect_sign=self.detect_sign,
             adjacency_radius=self.adjacency_radius,
             detect_threshold=self.detect_threshold,
             merge_thresh=self.merge_thresh,
             freq_min=self.freq_min,
             freq_max=self.freq_max,
             pc_per_chan=self.pc_per_chan,
             prm_template_name=self.prm_template_name,
             ironclust_src=ironclust_src,
         )
         si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
     finally:
         # One cleanup path instead of a bare except + re-raise plus a
         # second rmtree on success.
         if os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)
예제 #5
0
 def run(self):
     """Plot cross-correlograms of ``self.firings`` and save the figure to ``self.plot_out``.

     The recording (optionally restricted to ``self.channels``) is only used
     for its sampling frequency.
     """
     rx = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     if len(self.channels) > 0:
         rx = si.SubRecordingExtractor(parent_recording=rx, channel_ids=self.channels)
     sx = si.MdaSortingExtractor(firings_file=self.firings)
     widget = sw.CrossCorrelogramsWidget(samplerate=rx.getSamplingFrequency(), sorting=sx)
     widget.plot()
     save_plot(self.plot_out)
예제 #6
0
 def run(self):
     """Compute per-unit summary info and write it as JSON to ``self.json_out``.

     The recording is bandpass filtered (300-6000 Hz) and optionally
     restricted to ``self.channel_ids``. Templates are estimated from the
     first 1e6 frames, or the whole recording if it is shorter. Spike counts
     and firing rates use the full sorting.

     Each unit's entry records ``unit_id``, ``snr``, ``peak_channel`` (a
     channel id), ``num_events`` and ``firing_rate`` (events per second).
     """
     R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     if self.channel_ids and len(self.channel_ids) > 0:
         R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
     recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
     sorting = si.MdaSortingExtractor(firings_file=self.firings)

     # Clamp so short recordings do not request frames past their end.
     ef = min(int(1e6), recording.getNumFrames())
     recording_sub = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=ef)
     recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
     sorting_sub = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=ef)

     unit_ids = self.unit_ids
     if (not unit_ids) or (len(unit_ids) == 0):
         unit_ids = sorting.getUnitIds()

     channel_noise_levels = compute_channel_noise_levels(recording=recording)
     print('computing templates...')
     templates = compute_unit_templates(recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
     print('.')

     channel_ids = recording.getChannelIds()
     duration_sec = recording.getNumFrames() / recording.getSamplingFrequency()
     ret = []
     for i, unit_id in enumerate(unit_ids):
         template = templates[i]
         # Peak channel = channel (row) with the largest absolute template amplitude.
         peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
         train = sorting.getUnitSpikeTrain(unit_id=unit_id)
         ret.append(dict(
             unit_id=int(unit_id),
             # float() so numpy scalar types never reach the JSON writer.
             snr=float(compute_template_snr(template, channel_noise_levels)),
             peak_channel=int(channel_ids[peak_channel_index]),
             num_events=int(len(train)),
             firing_rate=float(len(train) / duration_sec),
         ))
     write_json_file(self.json_out, ret)
예제 #7
0
def real(name='franklab_tetrode', download=True):
    """Return ``(recording, None)`` for a named real (non-ground-truth) dataset.

    Only ``'franklab_tetrode'`` is recognised; any other name raises.
    ``download`` is accepted but currently unused.
    """
    if name != 'franklab_tetrode':
        raise Exception('Unrecognized name for real dataset: ' + name)
    dataset_dir = 'kbucket://b5ecdf1474c5/datasets/neuron_paper/franklab_tetrode'
    recording = se.MdaRecordingExtractor(dir_path=dataset_dir)
    return (recording, None)
예제 #8
0
    def run(self):
        """Run KiloSort on ``self.recording_dir`` and write the firings to ``self.firings_out``.

        The recording is optionally restricted to ``self.channels``. Scratch
        files go to a uniquely named directory under ``$TEMPDIR`` (default
        ``/tmp``), which is removed afterwards whether or not sorting
        succeeds.
        """
        code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + code

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = sorters.kilosort(recording=recording,
                                       tmpdir=tmpdir,
                                       detect_sign=self.detect_sign,
                                       adjacency_radius=self.adjacency_radius,
                                       detect_threshold=self.detect_threshold,
                                       merge_thresh=self.merge_thresh,
                                       freq_min=self.freq_min,
                                       freq_max=self.freq_max,
                                       pc_per_chan=self.pc_per_chan)
            si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # One cleanup path instead of a bare except + re-raise plus a
            # second rmtree on success.
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
예제 #9
0
 def run(self):
     """Write sample rate, channel count and duration (seconds) of the recording to ``self.json_out``."""
     rx = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     samplerate = rx.getSamplingFrequency()
     info = {
         'samplerate': samplerate,
         'num_channels': len(rx.getChannelIds()),
         'duration_sec': rx.getNumFrames() / samplerate,
     }
     write_json_file(self.json_out, info)
예제 #10
0
def spikeforest_sort(
        recording_dirname,  # directory of the MDA recording to sort
        sorter,
        sorting_params,
        _force_run=False,
        _force_save=False
    ):
    """Sort a recording with ``sorter``, caching the resulting firings via kbucket.

    The cache key combines ``sorter.name``, ``sorter.version``, a directory
    hash of ``recording_dirname`` and ``sorting_params``.

    Parameters
    ----------
    recording_dirname : str
        Directory of the MDA recording.
    sorter : callable
        Called as ``sorter(recording=..., **sorting_params)``; must expose
        ``name`` and ``version`` attributes.
    sorting_params : dict
        Keyword arguments forwarded to ``sorter`` (also part of the cache key).
    _force_run : bool
        Skip the cache lookup and always run the sorter.
    _force_save : bool
        On a cache hit, save the cached firings under the key again.

    Returns
    -------
    On a cache hit, an ``MdaSortingExtractor`` over the cached firings;
    otherwise the sorting object returned by ``sorter``.
    """
    import os
    import tempfile

    recording_signature = kb.computeDirHash(recording_dirname)
    signature_obj = dict(
        sorter_name=sorter.name,
        sorter_version=sorter.version,
        recording=recording_signature,
        sorting_params=sorting_params
    )
    if not _force_run:
        print('Looking up in cache...')
        firings = kb.realizeFile(key=signature_obj)
        if firings:
            print('Found')
            if _force_save:
                print('Saving')
                kb.saveFile(fname=firings, key=signature_obj)
            return si.MdaSortingExtractor(firings_file=firings)

    recording = si.MdaRecordingExtractor(recording_dirname)
    sorting = sorter(recording=recording, **sorting_params)

    # Write into a private temporary directory instead of a fixed
    # 'tmp_firings.mda' in the CWD, so concurrent calls cannot overwrite each
    # other's file and nothing is left behind.
    # NOTE(review): assumes kb.saveFile copies the file content into its
    # store (cache hits are later served via kb.realizeFile). Confirm this.
    with tempfile.TemporaryDirectory() as tmpdir:
        firings_path = os.path.join(tmpdir, 'tmp_firings.mda')
        si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=firings_path)
        kb.saveFile(fname=firings_path, key=signature_obj)

    return sorting
예제 #11
0
 def run(self):
     """Plot unit waveforms of ``self.firings`` and save the figure to ``self.plot_out``.

     The recording is optionally restricted to ``self.channels``.
     """
     rx = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
     if len(self.channels) > 0:
         rx = si.SubRecordingExtractor(parent_recording=rx, channel_ids=self.channels)
     sx = si.MdaSortingExtractor(firings_file=self.firings)
     widget = sw.UnitWaveformsWidget(recording=rx, sorting=sx)
     widget.plot()
     save_plot(self.plot_out)
예제 #12
0
 def recordingExtractor(self, download=False):
     """Return the MDA recording extractor for this dataset.

     If ``self._obj`` has a truthy ``'channels'`` entry, the recording is
     restricted to those channel ids.
     """
     recording = si.MdaRecordingExtractor(dataset_directory=self.directory(),
                                          download=download)
     channels = self._obj['channels'] if 'channels' in self._obj else None
     if not channels:
         return recording
     return si.SubRecordingExtractor(parent_recording=recording,
                                     channel_ids=channels)
예제 #13
0
 def run(self):
   R0=si.MdaRecordingExtractor(dataset_directory=self.recording_dir,download=False)
   R=st.filters.bandpass_filter(recording=R0,freq_min=300,freq_max=6000)
   N=R.getNumFrames()
   N2=int(N/2)
   channels=R.getChannelIds()
   if len(channels)>20: channels=channels[0:20]
   sw.TimeseriesWidget(recording=R,trange=[N2-4000,N2+0],channels=channels,width=12,height=5).plot()
   save_plot(self.jpg_out)
예제 #14
0
 def createSession(self):
     """Build a TimeseriesWidget over an in-memory copy of frames 0..10000 of the recording."""
     full = se.MdaRecordingExtractor(
         dataset_directory=self._recording_directory, download=False)
     head = se.SubRecordingExtractor(parent_recording=full,
                                     start_frame=0,
                                     end_frame=10000)
     traces = head.getTraces()
     in_memory = se.NumpyRecordingExtractor(
         timeseries=traces,
         samplerate=head.getSamplingFrequency())
     return SFW.TimeseriesWidget(recording=in_memory)
예제 #15
0
 def test_mda_extractor(self):
     """Round-trip ``self.RX`` / ``self.SX`` through the MDA format and compare with the originals."""
     mda_dir = self.test_dir + '/mda'
     firings_path = mda_dir + '/firings_true.mda'

     se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
     se.MdaSortingExtractor.write_sorting(self.SX, firings_path)
     rx_loaded = se.MdaRecordingExtractor(mda_dir)
     sx_loaded = se.MdaSortingExtractor(firings_path)

     self._check_recording_return_types(rx_loaded)
     self._check_recordings_equal(self.RX, rx_loaded)
     self._check_sorting_return_types(sx_loaded)
     self._check_sortings_equal(self.SX, sx_loaded)
예제 #16
0
def create_dumpable_extractors_from_existing(folder, RX, SX):
    """Dump ``RX``/``SX`` into ``folder`` and return file-backed extractors reloaded from it.

    The recording is written in MDA format and the sorting as
    ``folder/sorting.npz``. If ``RX`` has no shared ``'location'`` channel
    property, random 2-D locations (from the global NumPy RNG) are set on
    ``RX`` before writing.
    """
    has_locations = 'location' in RX.get_shared_channel_property_names()
    if not has_locations:
        RX.set_channel_locations(np.random.randn(RX.get_num_channels(), 2))

    out_dir = Path(folder)
    npz_path = out_dir / 'sorting.npz'

    se.MdaRecordingExtractor.write_recording(RX, out_dir)
    recording = se.MdaRecordingExtractor(out_dir)
    se.NpzSortingExtractor.write_sorting(SX, npz_path)
    sorting = se.NpzSortingExtractor(npz_path)
    return recording, sorting
예제 #17
0
 def run(self):
     """Save a unit-waveforms plot to ``self.jpg_out``.

     The recording is bandpass filtered (300-6000 Hz). At most the first 20
     channels are shown. When there are more than 20 units, an evenly strided
     subset of at most 20 units is shown.
     """
     R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     R = st.filters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
     S = si.MdaSortingExtractor(firings_file=self.firings)
     channels = R.getChannelIds()
     if len(channels) > 20:
         channels = channels[0:20]
     units = S.getUnitIds()
     if len(units) > 20:
         # Ceil division guarantees at most 20 units. With floor division,
         # e.g. 39 units gives a stride of 1 and keeps all 39.
         step = -(-len(units) // 20)
         units = units[::step]
     sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
     save_plot(self.jpg_out)
예제 #18
0
def yass_example(download=True, set_id=1):
    """Load one of the visapy_mea ground-truth datasets (sets 1..6) from kbucket.

    Parameters
    ----------
    download : bool
        Forwarded to ``MdaRecordingExtractor``.
    set_id : int
        Dataset number, 1 through 6.

    Returns
    -------
    (recording, sorting_true)
        The MDA recording and its ``firings_true.mda`` ground-truth sorting.

    Raises
    ------
    ValueError
        If ``set_id`` is not in 1..6. ValueError subclasses Exception, so
        callers catching Exception still work.
    """
    if set_id not in range(1, 7):
        raise ValueError(
            'Invalid ID for yass_example {} is not between 1..6'.format(set_id))
    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(set_id)
    IX = se.MdaRecordingExtractor(dataset_directory=dsdir, download=download)
    # kbucket:// paths are URLs: join with '/' rather than os.path.join,
    # which would insert a backslash on Windows.
    OX = se.MdaSortingExtractor(dsdir + '/firings_true.mda')
    return (IX, OX)
예제 #19
0
def toy_example(duration=10,
                num_channels=4,
                sampling_frequency=30000.0,
                K=10,
                dumpable=False,
                dump_folder=None,
                seed=None):
    """Create a synthetic recording/sorting pair.

    ``K`` random unit waveforms on ``num_channels`` channels are fired at
    random times for ``duration`` seconds. ``synthesize_timeseries`` then
    builds the traces with ``noise_level=10``. The recording is flagged
    ``is_filtered = True``.

    If ``dumpable`` is True, the pair is written to ``dump_folder`` (default
    ``'toy_example'``): the recording in MDA format and the sorting as
    ``sorting.npz``. File-backed extractors reloaded from there are returned
    instead of the in-memory Numpy ones.

    Returns
    -------
    (recording, sorting)
    """
    upsample_factor = 13

    unit_waveforms, geometry = synthesize_random_waveforms(
        K=K, M=num_channels, average_peak_amplitude=-100,
        upsamplefac=upsample_factor, seed=seed)
    spike_times, spike_labels = synthesize_random_firings(
        K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)

    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(spike_times, spike_labels.astype(np.int64))
    traces = synthesize_timeseries(
        sorting=sorting, waveforms=unit_waveforms, noise_level=10,
        sampling_frequency=sampling_frequency, duration=duration,
        waveform_upsamplefac=upsample_factor, seed=seed)
    sorting.set_sampling_frequency(sampling_frequency)

    recording = se.NumpyRecordingExtractor(
        timeseries=traces, sampling_frequency=sampling_frequency, geom=geometry)
    recording.is_filtered = True

    if not dumpable:
        return recording, sorting

    folder = Path('toy_example' if dump_folder is None else dump_folder)
    se.MdaRecordingExtractor.write_recording(recording, folder)
    mda_recording = se.MdaRecordingExtractor(folder)
    npz_path = folder / 'sorting.npz'
    se.NpzSortingExtractor.write_sorting(sorting, npz_path)
    return mda_recording, se.NpzSortingExtractor(npz_path)
예제 #20
0
 def run(self):
     """Run MountainSort4 on ``self.dataset_dir`` and write the firings to ``self.firings_out``.

     ``$NUM_WORKERS`` (default -1) is parsed as an int. Values <= 0 are
     passed to the sorter as ``num_workers=None``.
     """
     recording = si.MdaRecordingExtractor(self.dataset_dir)
     requested_workers = int(os.environ.get('NUM_WORKERS', -1))
     num_workers = requested_workers if requested_workers > 0 else None
     sorter_kwargs = dict(
         recording=recording,
         detect_sign=self.detect_sign,
         adjacency_radius=self.adjacency_radius,
         freq_min=self.freq_min,
         freq_max=self.freq_max,
         whiten=self.whiten,
         clip_size=self.clip_size,
         detect_threshold=self.detect_threshold,
         detect_interval=self.detect_interval,
         noise_overlap_threshold=self.noise_overlap_threshold,
         num_workers=num_workers,
     )
     sorting = sf.sorters.mountainsort4(**sorter_kwargs)
     si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
예제 #21
0
    def run(self):
        """Run MountainSort4 (ml_ms4alg) on ``self.recording_dir``; write firings to ``self.firings_out``.

        Optional preprocessing runs in this order:

        1. Bandpass filter, when ``freq_min`` or ``freq_max`` is set.
        2. Whitening, when ``self.whiten`` is truthy.

        ``$NUM_WORKERS`` is converted to int when set and non-empty, and
        passed as ``num_workers``.

        NOTE: ``self.noise_overlap_threshold`` is currently not applied,
        because the ml_ms4alg curation step is disabled.
        TODO: re-enable via ``ml_ms4alg.mountainsort4_curation`` if needed.
        """
        import spikeextractors as se
        import spiketoolkit as st
        import ml_ms4alg

        print('MountainSort4......')
        recording = se.MdaRecordingExtractor(self.recording_dir)
        env_workers = os.environ.get('NUM_WORKERS', None)
        num_workers = int(env_workers) if env_workers else env_workers

        wants_bandpass = self.freq_min or self.freq_max
        if wants_bandpass:
            recording = st.preprocessing.bandpass_filter(
                recording=recording,
                freq_min=self.freq_min,
                freq_max=self.freq_max)

        if self.whiten:
            recording = st.preprocessing.whiten(recording=recording)

        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers)

        se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                            save_path=self.firings_out)
예제 #22
0
def create_dumpable_extractors(folder,
                               duration=10,
                               num_channels=4,
                               sampling_frequency=30000.0,
                               K=10,
                               seed=None):
    """Create a toy recording/sorting pair and return file-backed copies stored in ``folder``.

    The recording is written in MDA format and the sorting as
    ``folder/sorting.npz``. Both are then reloaded from disk.
    """
    in_memory_rx, in_memory_sx = toy_example(duration=duration,
                                             num_channels=num_channels,
                                             K=K,
                                             sampling_frequency=sampling_frequency,
                                             seed=seed)

    out_dir = Path(folder)
    npz_path = out_dir / 'sorting.npz'

    se.MdaRecordingExtractor.write_recording(in_memory_rx, out_dir)
    mda_rx = se.MdaRecordingExtractor(out_dir)
    se.NpzSortingExtractor.write_sorting(in_memory_sx, npz_path)
    return mda_rx, se.NpzSortingExtractor(npz_path)
    def test_dump_load_multi_sub_extractor(self):
        """Check dump/load (check_dumping) of multi- and sub-extractors built on MDA-backed data."""
        # File-backed (dumpable) base extractors.
        mda_dir = self.test_dir + '/mda'
        firings_path = mda_dir + '/firings_true.mda'
        se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
        se.MdaSortingExtractor.write_sorting(self.SX, firings_path)
        rx = se.MdaRecordingExtractor(mda_dir)
        sx = se.MdaSortingExtractor(firings_path)

        # Recording wrappers: channel-concatenated, time-concatenated, channel subset.
        check_dumping(se.MultiRecordingChannelExtractor(recordings=[rx] * 3))
        check_dumping(se.MultiRecordingTimeExtractor(recordings=[rx] * 3))
        check_dumping(se.SubRecordingExtractor(rx, channel_ids=[0, 1]))

        # Sorting wrappers: unit subset and multi-sorting.
        check_dumping(se.SubSortingExtractor(sx, unit_ids=[1, 2]))
        check_dumping(se.MultiSortingExtractor(sortings=[sx] * 3))
예제 #24
0
  def run(self):
      """Compute per-unit summary statistics and write them as JSON to ``self.json_out``.

      The recording is optionally restricted to ``self.channel_ids``.
      Templates (up to 100 events per unit) and channel noise levels are
      computed on the unfiltered recording. The peak channel is the one with
      the largest peak-to-peak template amplitude.

      SNR is the peak |amplitude| of the unit's template recomputed on the
      300-6000 Hz bandpass-filtered peak channel, divided by that channel's
      noise level.

      Each entry records ``unit_id``, ``snr``, ``peak_channel`` (a channel
      id), ``num_events`` and ``firing_rate`` (events per second).
      """
      import spikewidgets as sw

      R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
      if self.channel_ids and len(self.channel_ids) > 0:
          R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)

      recording = R0

      sorting = si.MdaSortingExtractor(firings_file=self.firings)
      unit_ids = self.unit_ids
      if (not unit_ids) or (len(unit_ids) == 0):
          unit_ids = sorting.getUnitIds()

      channel_noise_levels = compute_channel_noise_levels(recording=recording)

      # Templates are computed on the full recording, not a subset.
      templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)

      channel_ids = recording.getChannelIds()
      duration_sec = recording.getNumFrames() / recording.getSamplingFrequency()
      ret = []
      for i, unit_id in enumerate(unit_ids):
          template = templates[i]
          max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
          peak_channel_index = np.argmax(max_p2p_amps_on_channels)
          peak_channel = channel_ids[peak_channel_index]
          # SubRecordingExtractor selects by channel *id*, not position; the
          # two differ whenever ids are not 0..N-1.
          R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
          R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
          templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
          template2 = templates2[0]
          train = sorting.getUnitSpikeTrain(unit_id=unit_id)
          info0 = dict()
          info0['unit_id'] = int(unit_id)
          # float() so numpy scalar types never reach the JSON writer.
          info0['snr'] = float(np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index])
          # peak_channel is already a channel id; do not index channel_ids with it again.
          info0['peak_channel'] = int(peak_channel)
          info0['num_events'] = int(len(train))
          info0['firing_rate'] = float(len(train) / duration_sec)
          ret.append(info0)
      write_json_file(self.json_out, ret)
예제 #25
0
 def run(self):
     """Run MountainSort4 on ``self.recording_dir`` and write the firings to ``self.firings_out``.

     The recording is optionally restricted to ``self.channels``.
     ``$NUM_WORKERS`` is converted to int when set and non-empty, and passed
     as ``num_workers``.
     """
     recording = si.MdaRecordingExtractor(self.recording_dir)
     if len(self.channels) > 0:
         recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)

     env_workers = os.environ.get('NUM_WORKERS', None)
     num_workers = int(env_workers) if env_workers else env_workers

     sorter_kwargs = dict(
         recording=recording,
         detect_sign=self.detect_sign,
         adjacency_radius=self.adjacency_radius,
         freq_min=self.freq_min,
         freq_max=self.freq_max,
         whiten=self.whiten,
         clip_size=self.clip_size,
         detect_threshold=self.detect_threshold,
         detect_interval=self.detect_interval,
         noise_overlap_threshold=self.noise_overlap_threshold,
         num_workers=num_workers,
     )
     sorting = sf.sorters.mountainsort4(**sorter_kwargs)
     si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
예제 #26
0
    def run(self):
        """Run IronClust via ``ironclust_helper`` and write the firings to ``self.firings_out``.

        Requires ``$IRONCLUST_PATH``. Dataset parameters are read from
        ``self.recording_dir`` and passed to the helper. The recording is
        optionally restricted to ``self.channels``.

        Scratch files go to a uniquely named directory under ``$TEMPDIR``
        (default ``/tmp``). It is removed afterwards unless
        ``self._keep_temp_files`` is truthy.

        Raises
        ------
        Exception
            If ``IRONCLUST_PATH`` is not set.
        """
        ironclust_path = os.environ.get('IRONCLUST_PATH', None)
        if not ironclust_path:
            raise Exception('Environment variable not set: IRONCLUST_PATH')

        suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = '{}/ironclust-tmp-{}'.format(os.environ.get('TEMPDIR', '/tmp'), suffix)

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            helper_kwargs = dict(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_path=ironclust_path,
                params=params,
            )
            sorting = ironclust_helper(**helper_kwargs)
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            # Clean up (unless asked to keep files), then propagate.
            keep = getattr(self, '_keep_temp_files', False)
            if os.path.exists(tmpdir) and not keep:
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
예제 #27
0
 def getRecording(self, download=True):
     """Return an MdaRecordingExtractor for ``self._kbucket_path``; ``download`` is forwarded."""
     return si.MdaRecordingExtractor(dataset_directory=self._kbucket_path,
                                     download=download)
예제 #28
0
def toy_example(duration: float = 10.,
                num_channels: int = 4,
                sampling_frequency: float = 30000.,
                K: int = 10,
                dumpable: bool = False,
                dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable'
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an
        MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a NumpySortingExtractor, otherwise it's an
        NpzSortingExtractor
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        # Recording is written as MDA into the folder; sorting goes to sorting.npz inside it.
        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX