Exemplo n.º 1
0
 def run(self):
     """Compare a sorting against ground truth and write the comparison table.

     Reads ``self.firings`` (sorting under test) and ``self.firings_true``
     (ground truth), optionally restricting the ground truth to
     ``self.units_true``. It builds a SortingComparison table and writes it twice:
     as a dict keyed by table row to ``self.json_out``, and as an HTML table
     string to ``self.html_out``.
     """
     sorting = si.MdaSortingExtractor(firings_file=self.firings)
     sorting_true = si.MdaSortingExtractor(firings_file=self.firings_true)
     # Treat None like an empty selection (use all ground-truth units)
     # instead of crashing on len(None).
     units_true = self.units_true
     if units_true is not None and len(units_true) > 0:
         sorting_true = si.SubSortingExtractor(parent_sorting=sorting_true, unit_ids=units_true)
     comparison = st.comparison.SortingComparison(sorting_true, sorting)
     df = sw.SortingComparisonTable(comparison=comparison).getDataframe()
     # Named table_dict rather than `json` to avoid shadowing the stdlib module.
     table_dict = df.transpose().to_dict()
     table_html = df.to_html(index=False)
     _write_json_file(table_dict, self.json_out)
     # NOTE(review): the HTML string goes through the JSON writer, so the
     # output file presumably contains a JSON-encoded string rather than raw
     # HTML — confirm this is what consumers of html_out expect.
     _write_json_file(table_html, self.html_out)
Exemplo n.º 2
0
def spikeforest_sort(
        recording_dirname,  # path of the MDA recording directory (not an extractor object)
        sorter,
        sorting_params,
        _force_run=False,
        _force_save=False
    ):
    """Run ``sorter`` on a recording, reusing a cached firings file when available.

    The cache key combines the sorter's ``name`` and ``version``, a hash of
    the recording directory, and ``sorting_params``.

    Args:
        recording_dirname: directory opened with ``si.MdaRecordingExtractor``
            and hashed with ``kb.computeDirHash``.
        sorter: callable with ``name`` and ``version`` attributes, invoked as
            ``sorter(recording=..., **sorting_params)``.
        sorting_params: keyword arguments forwarded to ``sorter``; also part
            of the cache key.
        _force_run: skip the cache lookup and always run the sorter.
        _force_save: on a cache hit, save the found firings file again under
            the same key.

    Returns:
        ``si.MdaSortingExtractor`` over the cached file on a cache hit,
        otherwise the object returned by ``sorter``.
    """
    dir_hash = kb.computeDirHash(recording_dirname)
    cache_key = dict(
        sorter_name=sorter.name,
        sorter_version=sorter.version,
        recording=dir_hash,
        sorting_params=sorting_params,
    )

    if not _force_run:
        print('Looking up in cache...')
        cached_firings = kb.realizeFile(key=cache_key)
        if cached_firings:
            print('Found')
            if _force_save:
                print('Saving')
                kb.saveFile(fname=cached_firings, key=cache_key)
            return si.MdaSortingExtractor(firings_file=cached_firings)

    result = sorter(recording=si.MdaRecordingExtractor(recording_dirname), **sorting_params)

    # NOTE(review): this fixed file name in the current working directory is
    # shared by every call, so concurrent runs can overwrite each other's
    # output, and the file is never removed. Consider a per-call temp file
    # once it is confirmed that kb.saveFile does not keep referring to the path.
    tmp_path = 'tmp_firings.mda'
    si.MdaSortingExtractor.writeSorting(sorting=result, save_path=tmp_path)
    kb.saveFile(fname=tmp_path, key=cache_key)

    return result
Exemplo n.º 3
0
 def run(self):
     """Plot cross-correlograms of the sorting and save the figure to ``self.plot_out``.

     The recording (optionally restricted to ``self.channels``) is only used
     here to obtain the sampling frequency passed to the widget.
     """
     recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     # Treat None like an empty selection (all channels) instead of crashing on len(None).
     if self.channels is not None and len(self.channels) > 0:
         recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
     sorting = si.MdaSortingExtractor(firings_file=self.firings)
     sw.CrossCorrelogramsWidget(samplerate=recording.getSamplingFrequency(), sorting=sorting).plot()
     save_plot(self.plot_out)
Exemplo n.º 4
0
 def run(self):
     """Plot unit waveforms for the sorting and save the figure to ``self.plot_out``.

     The recording may be restricted to ``self.channels``; an empty or None
     selection keeps all channels.
     """
     recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
     # Treat None like an empty selection (all channels) instead of crashing on len(None).
     if self.channels is not None and len(self.channels) > 0:
         recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
     sorting = si.MdaSortingExtractor(firings_file=self.firings)
     sw.UnitWaveformsWidget(recording=recording, sorting=sorting).plot()
     save_plot(self.plot_out)
Exemplo n.º 5
0
 def run(self):
     """Compute per-unit summary metrics and write them to ``self.json_out``.

     Each unit in ``self.unit_ids`` (or every unit of the sorting when that is
     empty/None) gets an entry with these fields:

     - ``unit_id``;
     - ``snr``, from ``compute_template_snr``;
     - ``peak_channel``, the channel id with the largest absolute template value;
     - ``num_events``;
     - ``firing_rate``, in events per second of the full recording.

     Templates are estimated from at most the first 1e6 frames of the
     bandpass-filtered (300-6000) recording.
     """
     R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     if (self.channel_ids) and (len(self.channel_ids) > 0):
         R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)
     recording = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
     sorting = si.MdaSortingExtractor(firings_file=self.firings)

     # Clamp the excerpt to the recording length: a recording shorter than
     # 1e6 frames must not get an end_frame past its last sample.
     ef = min(int(1e6), recording.getNumFrames())
     recording_sub = si.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=ef)
     # presumably materializes the filtered excerpt in memory — confirm
     recording_sub = MemoryRecordingExtractor(parent_recording=recording_sub)
     sorting_sub = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=ef)

     unit_ids = self.unit_ids
     if (not unit_ids) or (len(unit_ids) == 0):
         unit_ids = sorting.getUnitIds()

     channel_noise_levels = compute_channel_noise_levels(recording=recording)
     print('computing templates...')
     templates = compute_unit_templates(recording=recording_sub, sorting=sorting_sub, unit_ids=unit_ids)
     print('.')

     # Loop invariants: channel ids and full-recording duration in seconds.
     channel_ids = recording.getChannelIds()
     duration_sec = recording.getNumFrames() / recording.getSamplingFrequency()

     ret = []
     for i, unit_id in enumerate(unit_ids):
         template = templates[i]
         info0 = dict()
         info0['unit_id'] = int(unit_id)
         info0['snr'] = compute_template_snr(template, channel_noise_levels)
         # Rows of the template are indexed as channels (used to pick a channel id).
         peak_channel_index = np.argmax(np.max(np.abs(template), axis=1))
         info0['peak_channel'] = int(channel_ids[peak_channel_index])
         train = sorting.getUnitSpikeTrain(unit_id=unit_id)
         info0['num_events'] = int(len(train))
         info0['firing_rate'] = float(len(train) / duration_sec)
         ret.append(info0)
     write_json_file(self.json_out, ret)
Exemplo n.º 6
0
 def getSortingTrue(self):
   """Return the ground-truth sorting next to this recording, or None.

   Lists ``self._kbucket_path`` with ``mlp.readDir``. If
   ``firings_true.mda`` appears among its files, returns an
   ``si.MdaSortingExtractor`` for it; otherwise returns ``None``.
   """
   listing = mlp.readDir(self._kbucket_path)
   if 'firings_true.mda' not in listing['files']:
     return None
   return si.MdaSortingExtractor(firings_file=self._kbucket_path + '/firings_true.mda')
Exemplo n.º 7
0
 def test_mda_extractor(self):
     """Round-trip ``self.RX`` / ``self.SX`` through the MDA format and compare."""
     mda_dir = self.test_dir + '/mda'
     firings_path = mda_dir + '/firings_true.mda'

     # Write both fixtures to disk in MDA format...
     se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
     se.MdaSortingExtractor.write_sorting(self.SX, firings_path)

     # ...then load them back and check return types and contents.
     recording_back = se.MdaRecordingExtractor(mda_dir)
     sorting_back = se.MdaSortingExtractor(firings_path)
     self._check_recording_return_types(recording_back)
     self._check_recordings_equal(self.RX, recording_back)
     self._check_sorting_return_types(sorting_back)
     self._check_sortings_equal(self.SX, sorting_back)
Exemplo n.º 8
0
    def get_result_from_folder(output_folder: Union[str, Path]):
        """Load the sorting written to ``<output_folder>/tmp/firings.mda``.

        The sampling frequency is parsed as a float from
        ``<output_folder>/tmp/samplerate.txt``.
        """
        tmp_dir = Path(output_folder) / 'tmp'
        with open(str(tmp_dir / 'samplerate.txt'), 'r') as fh:
            sampling_frequency = float(fh.read())
        return se.MdaSortingExtractor(
            file_path=str(tmp_dir / 'firings.mda'),
            sampling_frequency=sampling_frequency,
        )
Exemplo n.º 9
0
 def run(self):
     """Plot waveforms for a subset of channels/units and save the figure to ``self.jpg_out``.

     The recording is bandpass-filtered (300-6000) first. The plot shows at
     most the first 20 channels and at most 20 units, taken evenly spaced
     from the sorting's unit list.
     """
     max_channels = 20
     max_units = 20
     R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     R = st.filters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
     S = si.MdaSortingExtractor(firings_file=self.firings)
     channels = R.getChannelIds()
     if len(channels) > max_channels:
         channels = channels[0:max_channels]
     units = S.getUnitIds()
     num_units = len(units)
     if num_units > max_units:
         # A stride of int(n/20) keeps every unit for 21-39 units and can
         # exceed the cap in general. Index evenly instead: for n > max_units
         # the indices (k*n)//max_units are distinct and < n, so exactly
         # max_units units are shown.
         units = [units[(k * num_units) // max_units] for k in range(max_units)]
     sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
     save_plot(self.jpg_out)
Exemplo n.º 10
0
def yass_example(download=True, set_id=1):
    """Return the (recording, ground-truth sorting) pair for a visapy_mea set.

    Args:
        download: forwarded to ``se.MdaRecordingExtractor``.
        set_id: dataset number; must be in 1..6.

    Returns:
        ``(IX, OX)``: the recording extractor for the kbucket dataset
        directory and the sorting extractor for its ``firings_true.mda``.

    Raises:
        ValueError: if ``set_id`` is not in 1..6. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    if set_id not in range(1, 7):
        raise ValueError(
            'Invalid ID for yass_example: {} is not between 1..6'.format(set_id))
    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(set_id)
    IX = se.MdaRecordingExtractor(dataset_directory=dsdir, download=download)
    # dsdir is a URL, so join with '/' rather than os.path.join (which would
    # insert '\\' on Windows).
    path1 = dsdir + '/firings_true.mda'
    print(path1)
    OX = se.MdaSortingExtractor(path1)
    return (IX, OX)
Exemplo n.º 11
0
    def test_dump_load_multi_sub_extractor(self):
        """Check that multi/sub extractors built on file-backed MDA extractors pass check_dumping."""
        # Generate dumpable (file-backed) MDA versions of the fixtures.
        mda_dir = self.test_dir + '/mda'
        firings_path = mda_dir + '/firings_true.mda'
        se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
        se.MdaSortingExtractor.write_sorting(self.SX, firings_path)
        rx = se.MdaRecordingExtractor(mda_dir)
        sx = se.MdaSortingExtractor(firings_path)

        # Recording wrappers: multi-channel, multi-time, and a two-channel sub-recording.
        rx_multi_chan = se.MultiRecordingChannelExtractor(recordings=[rx, rx, rx])
        check_dumping(rx_multi_chan)
        rx_multi_time = se.MultiRecordingTimeExtractor(recordings=[rx, rx, rx])
        check_dumping(rx_multi_time)
        rx_sub = se.SubRecordingExtractor(rx, channel_ids=[0, 1])
        check_dumping(rx_sub)

        # Sorting wrappers: a unit subset and a multi-sorting.
        sx_sub = se.SubSortingExtractor(sx, unit_ids=[1, 2])
        check_dumping(sx_sub)
        sx_multi = se.MultiSortingExtractor(sortings=[sx, sx, sx])
        check_dumping(sx_multi)
Exemplo n.º 12
0
  def run(self):
    """Compute per-unit SNR / peak-channel metrics and write them to ``self.json_out``.

    Each unit in ``self.unit_ids`` (or every unit of the sorting when that is
    empty/None) is processed in two steps:

    1. The peak channel is the one with the largest peak-to-peak amplitude in
       the template from the unfiltered recording.
    2. The template is recomputed on that single channel after bandpass
       filtering (300-6000). SNR is the max absolute value of that template
       divided by the peak channel's noise level.

    Each entry also records ``num_events`` and ``firing_rate`` (events per
    second of the full recording).
    """
    import spikewidgets as sw

    R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
    if (self.channel_ids) and (len(self.channel_ids) > 0):
      R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)

    # The unfiltered recording is used for the full templates and the noise levels;
    # filtering is applied per unit on the peak channel only.
    recording = R0

    sorting = si.MdaSortingExtractor(firings_file=self.firings)
    unit_ids = self.unit_ids
    if (not unit_ids) or (len(unit_ids) == 0):
      unit_ids = sorting.getUnitIds()

    channel_noise_levels = compute_channel_noise_levels(recording=recording)
    channel_ids = recording.getChannelIds()

    # Templates use the whole recording (not a time subset); max_num=100
    # presumably caps the events used per unit — confirm.
    templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)

    ret = []
    for i, unit_id in enumerate(unit_ids):
      template = templates[i]
      # Template rows are treated as channels (reduced over axis=1, then
      # used to index the channel ids).
      max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
      peak_channel_index = np.argmax(max_p2p_amps_on_channels)
      peak_channel = channel_ids[peak_channel_index]
      # SubRecordingExtractor expects channel *ids*. The positional index can
      # differ from the id, e.g. after the channel_ids restriction above.
      R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
      R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
      templates2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)
      template2 = templates2[0]
      info0 = dict()
      info0['unit_id'] = int(unit_id)
      # Noise levels are indexed by position, matching peak_channel_index.
      info0['snr'] = np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index]
      # peak_channel is already a channel id; do not index the id list with it again.
      info0['peak_channel'] = int(peak_channel)
      train = sorting.getUnitSpikeTrain(unit_id=unit_id)
      info0['num_events'] = int(len(train))
      info0['firing_rate'] = float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency()))
      ret.append(info0)
    write_json_file(self.json_out, ret)
Exemplo n.º 13
0
 def get_result_from_folder(output_folder):
     """Load the sorting stored as ``firings.mda`` directly inside ``output_folder``.

     ``output_folder`` may be a ``str`` or a ``pathlib.Path``. Previously a
     plain ``str`` failed on the ``/`` operator.
     """
     from pathlib import Path  # local import keeps this block self-contained

     firings_path = Path(output_folder) / 'firings.mda'
     return se.MdaSortingExtractor(str(firings_path))
Exemplo n.º 14
0
 def sorting(self):
     """Return an MDA sorting extractor for the path stored under ``self._obj['firings']``."""
     firings_path = self._obj['firings']
     return si.MdaSortingExtractor(firings_file=firings_path)
Exemplo n.º 15
0
 def sortingTrue(self):
     """Return an MDA sorting extractor for ``firings_true.mda`` in ``self.directory()``."""
     firings_true_path = self.directory() + '/firings_true.mda'
     return si.MdaSortingExtractor(firings_file=firings_true_path)