示例#1
0
 def run(self):
     """Compute summary info for each unit and write it as JSON to self.json_out.

     The recording is bandpass filtered (freq_min=300, freq_max=6000). Unit
     templates are estimated only from frames [0, 1e6) of the filtered
     recording, held in memory. Noise levels, spike counts and firing rates
     use the full filtered recording and the full sorting.
     """
     raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     if self.channel_ids and len(self.channel_ids) > 0:
         raw = si.SubRecordingExtractor(parent_recording=raw, channel_ids=self.channel_ids)
     filtered = sw.lazyfilters.bandpass_filter(recording=raw, freq_min=300, freq_max=6000)
     sorting = si.MdaSortingExtractor(firings_file=self.firings)

     # Leading window used only for template estimation.
     last_frame = int(1e6)
     window = si.SubRecordingExtractor(parent_recording=filtered, start_frame=0, end_frame=last_frame)
     window = MemoryRecordingExtractor(parent_recording=window)
     sorting_window = si.SubSortingExtractor(parent_sorting=sorting, start_frame=0, end_frame=last_frame)

     unit_ids = self.unit_ids
     if (not unit_ids) or (len(unit_ids) == 0):
         unit_ids = sorting.getUnitIds()

     channel_noise_levels = compute_channel_noise_levels(recording=filtered)
     print('computing templates...')
     templates = compute_unit_templates(recording=window, sorting=sorting_window, unit_ids=unit_ids)
     print('.')

     summaries = []
     for idx, unit_id in enumerate(unit_ids):
         template = templates[idx]
         # Peak channel: the channel with the largest absolute template value.
         peak_idx = np.argmax(np.max(np.abs(template), axis=1))
         spike_train = sorting.getUnitSpikeTrain(unit_id=unit_id)
         duration_sec = filtered.getNumFrames() / filtered.getSamplingFrequency()
         summaries.append({
             'unit_id': int(unit_id),
             'snr': compute_template_snr(template, channel_noise_levels),
             'peak_channel': int(filtered.getChannelIds()[peak_idx]),
             'num_events': int(len(spike_train)),
             'firing_rate': float(len(spike_train) / duration_sec),
         })
     write_json_file(self.json_out, summaries)
示例#2
0
    def run(self, payload, next_elem):
        """Build a recording extractor from ``self.param_list`` and return it.

        Each entry of ``self.param_list`` is a ``{"name": ..., "value": ...}``
        dict. The names ``probe_path``, ``channel_map`` and ``channel_groups``
        are handled here. Every other entry is passed as a keyword argument to
        ``self._spif_class``.

        If a probe file is given, it is loaded with the channel map and groups.
        Otherwise ``channel_map`` (if set) selects a subset of channels, and
        ``channel_groups`` (if set) is applied to the resulting channels.

        ``payload`` and ``next_elem`` are accepted but not used here.

        Raises:
            AssertionError: if some id in ``channel_map`` is not a channel id
                of the recording.
        """
        spif_params_dict = {}
        probe_file = None
        # Defaults so these names exist even when the params are not supplied
        # (previously this raised UnboundLocalError).
        channel_map = None
        channel_groups = None
        for param in self.param_list:
            if param["name"] == "probe_path":
                probe_file = param["value"]
            elif param["name"] == "channel_map":
                channel_map = param["value"]
            elif param["name"] == "channel_groups":
                channel_groups = param["value"]
            else:
                spif_params_dict[param["name"]] = param["value"]

        recording = self._spif_class(**spif_params_dict)

        if probe_file:
            recording = recording.load_probe_file(probe_file, channel_map,
                                                  channel_groups)
        else:
            if channel_map:
                # channel_map must be a subset of the recording's channel ids,
                # since it is used to select channels below.
                recording_channel_ids = recording.get_channel_ids()
                assert all(chan in recording_channel_ids
                           for chan in channel_map), (
                               "all channel_ids in "
                               "'channel_map' must be in recording channel ids")
                recording = se.SubRecordingExtractor(recording,
                                                     channel_ids=channel_map)
            if channel_groups:
                recording.set_channel_groups(recording.get_channel_ids(),
                                             channel_groups)

        return recording
示例#3
0
文件: yass1.py 项目: yger/spikeforest
    def run(self):
        """Run YASS on the recording and write the sorting to self.firings_out.

        YASS output goes into a randomly named scratch directory under
        $TEMPDIR (default '/tmp'). Deleting that directory is currently
        disabled: only a notice is printed, on failure and on success.
        """
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + suffix

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting, _ = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            # Cleanup intentionally disabled; report and propagate the error.
            if os.path.exists(tmpdir):
                print('not deleted tmpdir1')
            raise
        if not getattr(self, '_keep_temp_files', False):
            print('not deleted tmpdir2')
示例#4
0
def old_fetch_average_waveform_plot_data(recording_object, sorting_object,
                                         unit_id):
    """Return the average waveform of ``unit_id`` on its peak channel.

    Only the first 30 seconds of the recording (30 * sampling frequency
    frames, or the whole recording if shorter) are used.

    Args:
        recording_object: recording descriptor for LabboxEphysRecordingExtractor.
        sorting_object: sorting descriptor for LabboxEphysSortingExtractor.
        unit_id: unit whose template is computed.

    Returns:
        dict with ``channel_id`` (the id of the channel with the largest
        absolute template value) and ``average_waveform`` (that channel's
        template as a list). Both are None when the unit has no spikes in
        the window.

    Raises:
        Exception: if template computation fails. The underlying error is
            chained as ``__cause__``.
    """
    import labbox_ephys as le
    R = le.LabboxEphysRecordingExtractor(recording_object)
    S = le.LabboxEphysSortingExtractor(sorting_object)

    start_frame = 0
    # Frame indices must be integers (sampling frequency may be a float) and
    # must not run past the end of a short recording.
    end_frame = min(int(R.get_sampling_frequency() * 30), R.get_num_frames())
    R0 = se.SubRecordingExtractor(parent_recording=R,
                                  start_frame=start_frame,
                                  end_frame=end_frame)
    S0 = se.SubSortingExtractor(parent_sorting=S,
                                start_frame=start_frame,
                                end_frame=end_frame)

    times0 = S0.get_unit_spike_train(unit_id=unit_id)
    if len(times0) == 0:
        # no waveforms found
        return dict(channel_id=None, average_waveform=None)
    try:
        average_waveform = st.postprocessing.get_unit_templates(
            recording=R0, sorting=S0, unit_ids=[unit_id])[0]
    except Exception as err:
        raise Exception(
            f'Error getting unit templates for unit {unit_id}') from err

    channel_maximums = np.max(np.abs(average_waveform), axis=1)
    maxchan_index = np.argmax(channel_maximums)
    maxchan_id = R0.get_channel_ids()[maxchan_index]

    return dict(channel_id=maxchan_id,
                average_waveform=average_waveform[maxchan_index, :].tolist())
示例#5
0
 def run(self):
     """Run spyking-circus via sf.sorters and write firings to self.firings_out.

     A randomly named scratch directory under $TEMPDIR (default '/tmp') is
     created for the sorter. It is removed afterwards, whether or not
     sorting succeeded. The core count comes from $NUM_WORKERS (default 2).
     """
     code = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
     # Prefix previously read 'ironclust-tmp-' (copy/paste from the IronClust runner).
     tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-tmp-' + code

     # os.environ values are strings; pass an integer core count to the sorter.
     num_workers = int(os.environ.get('NUM_WORKERS', 2))

     try:
         recording = si.MdaRecordingExtractor(self.recording_dir)
         if len(self.channels) > 0:
             recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
         if not os.path.exists(tmpdir):
             os.mkdir(tmpdir)
         sorting = sf.sorters.spyking_circus(
             recording=recording,
             output_folder=tmpdir,
             probe_file=None,
             file_name=None,
             detect_sign=self.detect_sign,
             adjacency_radius=self.adjacency_radius,
             spike_thresh=self.spike_thresh,
             template_width_ms=self.template_width_ms,
             filter=self.filter,
             merge_spikes=True,
             n_cores=num_workers,
             electrode_dimensions=None,
             whitening_max_elts=self.whitening_max_elts,
             clustering_max_elts=self.clustering_max_elts,
         )
         si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
     finally:
         # Single cleanup path for both success and failure.
         if os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)
示例#6
0
    def run(self):
        """Run Kilosort2 through kilosort2_helper; write firings to self.firings_out.

        Scratch files go to a randomly named directory under $TEMPDIR
        (default '/tmp'). The keep flag below is hard-coded to True, so the
        removal branches never run and the directory is always kept.
        NOTE(review): other runners read getattr(self, '_keep_temp_files',
        False) instead; confirm whether hard-coding is intended.
        """
        keep_temp_files = True

        rand_tag = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + rand_tag

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorter_kwargs = dict(
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                minFR=self.minFR,
            )
            sorting = kilosort2_helper(recording=recording, tmpdir=tmpdir,
                                       **sorter_kwargs)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir) and not keep_temp_files:
                print('removing tmpdir1')
                shutil.rmtree(tmpdir)
            raise
        if not keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
示例#7
0
 def run(self):
     """Run IronClust via sf.sorters.ironclust; write firings to self.firings_out.

     Requires the IRONCLUST_SRC environment variable. A randomly named
     scratch directory under $TEMPDIR (default '/tmp') is removed on success,
     and on failure if it exists.
     """
     ironclust_src = os.environ.get('IRONCLUST_SRC', None)
     if not ironclust_src:
         raise Exception('Environment variable not set: IRONCLUST_SRC')
     tag = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
     tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + tag

     try:
         recording = si.MdaRecordingExtractor(self.recording_dir)
         if len(self.channels) > 0:
             recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
         if not os.path.exists(tmpdir):
             os.mkdir(tmpdir)
         sorter_kwargs = dict(
             detect_sign=self.detect_sign,
             adjacency_radius=self.adjacency_radius,
             detect_threshold=self.detect_threshold,
             merge_thresh=self.merge_thresh,
             freq_min=self.freq_min,
             freq_max=self.freq_max,
             pc_per_chan=self.pc_per_chan,
             prm_template_name=self.prm_template_name,
         )
         sorting = sf.sorters.ironclust(recording=recording, tmpdir=tmpdir,
                                        ironclust_src=ironclust_src, **sorter_kwargs)
         si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
     except BaseException:
         if os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)
         raise
     shutil.rmtree(tmpdir)
示例#8
0
 def run(self):
     """Plot unit waveforms of self.firings on the recording; save to self.plot_out."""
     rx = si.MdaRecordingExtractor(dataset_directory=self.recording_dir)
     if len(self.channels) > 0:
         rx = si.SubRecordingExtractor(parent_recording=rx, channel_ids=self.channels)
     sx = si.MdaSortingExtractor(firings_file=self.firings)
     widget = sw.UnitWaveformsWidget(recording=rx, sorting=sx)
     widget.plot()
     save_plot(self.plot_out)
示例#9
0
    def run(self):
        """Run Kilosort via sorters.kilosort; write firings to self.firings_out.

        Uses a randomly named scratch directory under $TEMPDIR (default
        '/tmp'). It is removed on success, and on failure if it exists.
        """
        tag = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-tmp-' + tag

        try:
            recording = si.MdaRecordingExtractor(self.recording_dir)
            if len(self.channels) > 0:
                recording = si.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorter_kwargs = dict(
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
            )
            sorting = sorters.kilosort(recording=recording, tmpdir=tmpdir,
                                       **sorter_kwargs)
            si.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
            raise
        shutil.rmtree(tmpdir)
示例#10
0
    def run(self):
        """Run JRClust through jrclust_helper; write firings to self.firings_out.

        Every parameter listed in self.PARAMETERS is read from this object and
        forwarded to the helper. The scratch directory from _get_tmpdir is
        removed afterwards, unless self._keep_temp_files is truthy.
        """
        tmpdir = _get_tmpdir('jrclust')

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            all_params = {p.name: getattr(self, p.name) for p in self.PARAMETERS}

            sorting = jrclust_helper(recording=recording, tmpdir=tmpdir,
                                     params=params, **all_params)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except BaseException:
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
示例#11
0
 def run(self):
     """Plot cross-correlograms of self.firings; save the figure to self.plot_out."""
     rx = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     if len(self.channels) > 0:
         rx = si.SubRecordingExtractor(parent_recording=rx, channel_ids=self.channels)
     sx = si.MdaSortingExtractor(firings_file=self.firings)
     samplerate = rx.getSamplingFrequency()
     sw.CrossCorrelogramsWidget(samplerate=samplerate, sorting=sx).plot()
     save_plot(self.plot_out)
示例#12
0
 def _on_view_timeseries(self):
     """Show a bandpass-filtered excerpt of the current recording.

     The excerpt starts at frame 0. It is 10 s long when the recording file
     is local and 1 s otherwise. It is filtered with freq_min=300,
     freq_max=6000 and shown in a TimeseriesWidget.
     """
     rx = self._recording.recordingExtractor()
     samplerate = rx.getSamplingFrequency()
     seconds = 10 if self._recording.recordingFileIsLocal() else 1
     rx = se.SubRecordingExtractor(parent_recording=rx,
                                   start_frame=int(samplerate * 0),
                                   end_frame=int(samplerate * seconds))
     rx = st.preprocessing.bandpass_filter(recording=rx,
                                           freq_min=300,
                                           freq_max=6000)
     self._view = SFW.TimeseriesWidget(recording=rx)
     self.refresh()
示例#13
0
    def test_ttl_frames_in_sub_multi(self):
        """TTL events are windowed by sub-extractors and offset by time-concatenation."""
        # Sub recording over the middle third of the frames.
        total_frames = self.example_info['num_frames']
        start_frame = total_frames // 3
        end_frame = 2 * total_frames // 3
        RX_sub = se.SubRecordingExtractor(self.RX,
                                          start_frame=start_frame,
                                          end_frame=end_frame)
        original_ttls = self.RX.get_ttl_events()[0]
        in_window = (original_ttls >= start_frame) & (original_ttls < end_frame)
        expected_sub_ttls = original_ttls[np.where(in_window)[0]] - start_frame
        self.assertTrue(
            np.array_equal(RX_sub.get_ttl_events()[0], expected_sub_ttls))

        # Three copies concatenated in time: copy k is shifted by k * num_frames.
        RX_multi = se.MultiRecordingTimeExtractor(
            recordings=[self.RX, self.RX, self.RX])
        ttls_originals = self.RX.get_ttl_events()[0]
        num_ttls = len(ttls_originals)
        self.assertEqual(len(RX_multi.get_ttl_events()[0]), 3 * num_ttls)
        for copy_idx in range(3):
            segment = RX_multi.get_ttl_events()[0][copy_idx * num_ttls:(copy_idx + 1) * num_ttls]
            self.assertTrue(
                np.array_equal(segment,
                               ttls_originals + copy_idx * self.RX.get_num_frames()))
示例#14
0
 def recordingExtractor(self, download=False):
     """Return an MDA recording extractor for self.directory().

     When self._obj has a truthy 'channels' entry, the extractor is
     restricted to those channel ids.
     """
     X = si.MdaRecordingExtractor(dataset_directory=self.directory(),
                                  download=download)
     channels = self._obj['channels'] if 'channels' in self._obj else None
     if channels:
         X = si.SubRecordingExtractor(parent_recording=X, channel_ids=channels)
     return X
示例#15
0
 def __init__(self, context):
     """Store the context and build a TimeseriesWidget for its recording.

     The widget shows a bandpass-filtered (freq_min=300, freq_max=6000)
     excerpt starting at frame 0. It is 10 s long when the recording file is
     local and 1 s otherwise.
     """
     vd.Component.__init__(self)
     self._context = context
     rx = self._context.recording.recordingExtractor()
     samplerate = rx.getSamplingFrequency()
     print(self._context.recording.recordingFileIsLocal())
     seconds = 10 if self._context.recording.recordingFileIsLocal() else 1
     rx = se.SubRecordingExtractor(parent_recording=rx,
                                   start_frame=int(samplerate * 0),
                                   end_frame=int(samplerate * seconds))
     rx = st.preprocessing.bandpass_filter(recording=rx,
                                           freq_min=300,
                                           freq_max=6000)
     self._timeseries_widget = TimeseriesWidget(recording=rx)
 def run(self):
     """Write samplerate, channel count and duration (seconds) as JSON to self.json_out."""
     recording = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     if len(self.channels) > 0:
         recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)
     samplerate = recording.getSamplingFrequency()
     info = {
         'samplerate': samplerate,
         'num_channels': len(recording.getChannelIds()),
         'duration_sec': recording.getNumFrames() / samplerate,
     }
     write_json_file(self.json_out, info)
示例#17
0
 def createSession(self):
     """Return a full-browser TimeseriesWidget showing frames [0, 10000) held in memory."""
     source = SFMdaRecordingExtractor(
         dataset_directory=self._recording_directory, download=False)
     excerpt = se.SubRecordingExtractor(
         parent_recording=source, start_frame=0, end_frame=10000)
     in_memory = se.NumpyRecordingExtractor(
         timeseries=excerpt.get_traces(),
         samplerate=excerpt.get_sampling_frequency())
     widget = SFW.TimeseriesWidget(recording=in_memory)
     _make_full_browser(widget)
     return widget
示例#18
0
def compute_units_info(*, recording, sorting, channel_ids=None, unit_ids=None):
    """Compute per-unit summary statistics for a sorting.

    The (optionally channel-restricted) recording is loaded into RAM and
    bandpass filtered (freq_min=300, freq_max=6000). Templates (max_num=100)
    and channel noise levels are then computed on the filtered data.

    Args:
        recording: recording extractor.
        sorting: sorting extractor.
        channel_ids: channel ids to restrict to. None or empty means all.
        unit_ids: units to report. None or empty means all units of sorting.

    Returns:
        list of dicts, one per unit in ``unit_ids`` order, with keys:
        'unit_id'; 'snr' (peak |template| on the peak channel divided by that
        channel's noise level); 'peak_channel' (channel id with the largest
        peak-to-peak template amplitude); 'num_events'; 'firing_rate'
        (events per second over the whole recording).
    """
    if channel_ids is not None and len(channel_ids) > 0:
        recording = si.SubRecordingExtractor(parent_recording=recording,
                                             channel_ids=channel_ids)

    # The NumpyRecordingExtractor below is built from raw traces only, so it is
    # not handed these ids. Keep them to report the peak channel in the caller's ids.
    original_channel_ids = list(recording.get_channel_ids())

    # load into memory
    print('Loading recording into RAM...')
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    # do filtering
    print('Filtering...')
    recording = bandpass_filter(recording=recording,
                                freq_min=300,
                                freq_max=6000)
    recording = si.NumpyRecordingExtractor(
        timeseries=recording.get_traces(),
        samplerate=recording.get_sampling_frequency())

    if unit_ids is None or len(unit_ids) == 0:
        unit_ids = sorting.get_unit_ids()

    print('Computing channel noise levels...')
    channel_noise_levels = compute_channel_noise_levels(recording=recording)

    # No longer use subset to compute the templates
    print('Computing unit templates...')
    templates = compute_unit_templates(recording=recording,
                                       sorting=sorting,
                                       unit_ids=unit_ids,
                                       max_num=100)

    duration_sec = recording.get_num_frames() / recording.get_sampling_frequency()

    ret = []
    for i, unit_id in enumerate(unit_ids):
        print('Unit {} of {} (id={})'.format(i + 1, len(unit_ids), unit_id))
        template = templates[i]
        # Peak channel = the channel with the largest peak-to-peak amplitude.
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        peak_channel_index = int(np.argmax(max_p2p_amps_on_channels))
        peak_signal = np.max(np.abs(template[peak_channel_index, :]))
        train = sorting.get_unit_spike_train(unit_id=unit_id)
        ret.append({
            'unit_id': int(unit_id),
            'snr': peak_signal / channel_noise_levels[peak_channel_index],
            # Map position -> original channel id (the old code indexed the id
            # list with an id, returning the wrong channel).
            'peak_channel': int(original_channel_ids[peak_channel_index]),
            'num_events': int(len(train)),
            'firing_rate': float(len(train) / duration_sec),
        })
    return ret
 def run(self):
     """Plot a filtered 4000-frame excerpt ending at the recording midpoint.

     At most the first 20 channels are shown. The figure is saved to
     self.jpg_out.
     """
     raw = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=False)
     if len(self.channels) > 0:
         raw = si.SubRecordingExtractor(parent_recording=raw, channel_ids=self.channels)
     filtered = sw.lazyfilters.bandpass_filter(recording=raw, freq_min=300, freq_max=6000)
     midpoint = int(filtered.getNumFrames() / 2)
     shown_channels = filtered.getChannelIds()
     if len(shown_channels) > 20:
         shown_channels = shown_channels[0:20]
     widget = sw.TimeseriesWidget(recording=filtered,
                                  trange=[midpoint - 4000, midpoint],
                                  channels=shown_channels,
                                  width=12, height=5)
     widget.plot()
     save_plot(self.jpg_out)
示例#20
0
  def run(self):
      """Compute per-unit info for self.firings and write it as JSON to self.json_out.

      The peak channel is the channel with the largest peak-to-peak amplitude
      of the unfiltered template. SNR is max |template| on that channel after
      bandpass filtering (freq_min=300, freq_max=6000), divided by that
      channel's noise level from compute_channel_noise_levels on the
      unfiltered recording.
      """
      import spikewidgets as sw

      R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
      if (self.channel_ids) and (len(self.channel_ids) > 0):
          R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channel_ids)

      # NOTE(review): templates and noise levels use the unfiltered recording.
      recording = R0

      sorting = si.MdaSortingExtractor(firings_file=self.firings)
      unit_ids = self.unit_ids
      if (not unit_ids) or (len(unit_ids) == 0):
          unit_ids = sorting.getUnitIds()

      channel_noise_levels = compute_channel_noise_levels(recording=recording)

      # No longer use subset to compute the templates
      templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)

      ret = []
      for i, unit_id in enumerate(unit_ids):
          template = templates[i]
          max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
          peak_channel_index = np.argmax(max_p2p_amps_on_channels)
          peak_channel = recording.getChannelIds()[peak_channel_index]
          # SubRecordingExtractor selects by channel id, not by position; the
          # old code passed the position and picked the wrong channel on subsets.
          R1 = si.SubRecordingExtractor(parent_recording=recording, channel_ids=[peak_channel])
          R1f = sw.lazyfilters.bandpass_filter(recording=R1, freq_min=300, freq_max=6000)
          template2 = compute_unit_templates(recording=R1f, sorting=sorting, unit_ids=[unit_id], max_num=100)[0]
          train = sorting.getUnitSpikeTrain(unit_id=unit_id)
          info0 = dict(
              unit_id=int(unit_id),
              snr=np.max(np.abs(template2)) / channel_noise_levels[peak_channel_index],
              # peak_channel is already an id (the old code re-indexed ids with it).
              peak_channel=int(peak_channel),
              num_events=int(len(train)),
              firing_rate=float(len(train) / (recording.getNumFrames() / recording.getSamplingFrequency())),
          )
          ret.append(info0)
      write_json_file(self.json_out, ret)
示例#21
0
    def __init__(self, arg: Union[str, dict], download: bool=False):
        """Wrap a labbox-ephys recording object as a recording extractor.

        Args:
            arg: a recording object, or something _create_object_for_arg can
                resolve to one (presumably a URI/path string or a dict;
                confirm against _create_object_for_arg). The resolved dict
                must have 'recording_format' and 'data' keys.
            download: forwarded to MdaRecordingExtractor and to the nested
                LabboxEphysRecordingExtractor of the 'subrecording' and
                'filtered' formats.

        Raises:
            Exception: if recording_format is not one of 'mda', 'nrs', 'bin1',
                'snippets1', 'subrecording', 'filtered', or if a
                'subrecording' specifies 'groups' (not yet handled).
        """
        super().__init__()
        obj = _create_object_for_arg(arg)
        assert obj is not None
        self._object: dict = obj

        recording_format = self._object['recording_format']
        data: dict = self._object['data']
        if recording_format == 'mda':
            self._recording: se.RecordingExtractor = MdaRecordingExtractor(timeseries_path=data['raw'], samplerate=data['params']['samplerate'], geom=np.array(data['geom']), download=download)
        elif recording_format == 'nrs':
            self._recording: se.RecordingExtractor = NrsRecordingExtractor(**data)
        # elif recording_format == 'nwb':
        #     path0 = kp.load_file(data['path'])
        #     self._recording: se.RecordingExtractor = NwbRecordingExtractor(path0, electrical_series_name='e-series')
        elif recording_format == 'bin1':
            self._recording: se.RecordingExtractor = Bin1RecordingExtractor(**data, p2p=True)
        elif recording_format == 'snippets1':
            self._recording: se.RecordingExtractor = Snippets1RecordingExtractor(snippets_h5_uri=data['snippets_h5_uri'], p2p=True)
        elif recording_format == 'subrecording':
            R = LabboxEphysRecordingExtractor(data['recording'], download=download)
            if 'channel_ids' in data:
                channel_ids = np.array(data['channel_ids'])
            elif 'group' in data:
                # Keep only the channels whose group equals data['group'].
                channel_ids = np.array(R.get_channel_ids())
                groups = R.get_channel_groups(channel_ids=R.get_channel_ids())
                group = int(data['group'])
                inds = np.where(np.array(groups) == group)[0]
                channel_ids = channel_ids[inds]
            elif 'groups' in data:
                raise Exception('This case not yet handled.')
            else:
                channel_ids = None
            # Each frame bound is optional on its own; a missing bound is passed
            # as None, the same value used when no frame range is given.
            # (Previously 'start_frame' without 'end_frame' raised KeyError and
            # 'end_frame' alone was ignored.)
            start_frame = data.get('start_frame')
            end_frame = data.get('end_frame')
            self._recording: se.RecordingExtractor = se.SubRecordingExtractor(
                parent_recording=R,
                channel_ids=channel_ids,
                start_frame=start_frame,
                end_frame=end_frame
            )
        elif recording_format == 'filtered':
            R = LabboxEphysRecordingExtractor(data['recording'], download=download)
            self._recording: se.RecordingExtractor = _apply_filters(recording=R, filters=data['filters'])
        else:
            raise Exception(f'Unexpected recording format: {recording_format}')

        self.copy_channel_properties(recording=self._recording)
def get_unit_waveforms(recording, sorting, unit_ids, channel_ids_by_unit,
                       snippet_len):
    """Extract spike snippets per unit, processing the recording in time chunks.

    Args:
        recording: recording extractor.
        sorting: sorting extractor.
        unit_ids: units to extract, in output order.
        channel_ids_by_unit: forwarded to _get_unit_waveforms_for_chunk
            (presumably the channel neighbourhood of each unit; confirm).
        snippet_len: either a number N, split into [N - int(N/2), int(N/2)],
            or a 2-element list/tuple used as-is.

    Returns:
        list with one array per unit. Each array concatenates that unit's
        per-chunk results along axis 0 (per the inline shape notes:
        num_events x num_channels_in_nbhd x snippet length).
    """
    if not isinstance(snippet_len, (list, tuple)):
        b = int(snippet_len / 2)
        a = int(snippet_len) - b
        snippet_len = [a, b]

    num_channels = recording.get_num_channels()
    num_frames = recording.get_num_frames()
    num_bytes_per_chunk = 1000 * 1000 * 1000  # ? how to choose this
    num_bytes_per_frame = num_channels * 2  # counts 2 bytes per sample
    # A chunk size is a frame count, so keep it an integer (and at least 1);
    # true division produced a float here.
    chunk_size = max(1, num_bytes_per_chunk // num_bytes_per_frame)
    padding_size = 100 + snippet_len[0] + snippet_len[1]  # a bit excess padding
    chunks = _divide_recording_into_time_chunks(num_frames=num_frames,
                                                chunk_size=chunk_size,
                                                padding_size=padding_size)
    all_unit_waveforms = [[] for _ in range(len(unit_ids))]
    for ii, chunk in enumerate(chunks):
        # chunk: {istart, iend, istart_with_padding, iend_with_padding} # include padding
        print(
            f'Processing chunk {ii + 1} of {len(chunks)}; chunk-range: {chunk["istart_with_padding"]} {chunk["iend_with_padding"]}; num-frames: {num_frames}'
        )
        recording_chunk = se.SubRecordingExtractor(
            parent_recording=recording,
            start_frame=chunk['istart_with_padding'],
            end_frame=chunk['iend_with_padding'])
        # note that the efficiency of this operation may need improvement (really depends on sorting extractor implementation)
        sorting_chunk = se.SubSortingExtractor(parent_sorting=sorting,
                                               start_frame=chunk['istart'],
                                               end_frame=chunk['iend'])
        print(f'Getting unit waveforms for chunk {ii + 1} of {len(chunks)}')
        # num_events_in_chunk x num_channels_in_nbhd[unit_id] x len_of_one_snippet
        unit_waveforms = _get_unit_waveforms_for_chunk(
            recording=recording_chunk,
            sorting=sorting_chunk,
            # just the padding size (except 0 for first chunk)
            frame_offset=chunk['istart'] - chunk['istart_with_padding'],
            unit_ids=unit_ids,
            snippet_len=snippet_len,
            channel_ids_by_unit=channel_ids_by_unit)
        for i_unit, x in enumerate(unit_waveforms):
            all_unit_waveforms[i_unit].append(x)

    # concatenate the results over the chunks
    # tot_num_events_for_unit x num_channels_in_nbhd[unit_id] x len_of_one_snippet
    return [
        np.concatenate(all_unit_waveforms[i_unit], axis=0)
        for i_unit in range(len(unit_ids))
    ]
 def run(self):
     """Plot filtered unit waveforms and save the figure to self.jpg_out.

     The recording is bandpass filtered (freq_min=300, freq_max=6000). At most
     the first 20 channels are shown. Units are self.units if non-empty, else
     all units of the sorting; more than 20 are evenly subsampled to at most 20.
     """
     R0 = si.MdaRecordingExtractor(dataset_directory=self.recording_dir, download=True)
     if len(self.channels) > 0:
         R0 = si.SubRecordingExtractor(parent_recording=R0, channel_ids=self.channels)
     R = sw.lazyfilters.bandpass_filter(recording=R0, freq_min=300, freq_max=6000)
     S = si.MdaSortingExtractor(firings_file=self.firings)
     channels = R.getChannelIds()
     if len(channels) > 20:
         channels = channels[0:20]
     if len(self.units) > 0:
         units = self.units
     else:
         units = S.getUnitIds()
     if len(units) > 20:
         # ceil(len/20) stride guarantees at most 20 units; the old floor stride
         # was 1 for 21..39 units and did not subsample at all.
         stride = -(-len(units) // 20)
         units = units[::stride]
     sw.UnitWaveformsWidget(recording=R, sorting=S, channels=channels, unit_ids=units).plot()
     save_plot(self.jpg_out)
示例#24
0
    def test_multi_sub_recording_extractor(self):
        """Sub-selections of multi-recording extractors reproduce the inputs."""
        # Time concatenation: epoch 'C' should equal self.RX.
        time_multi = se.MultiRecordingTimeExtractor(
            recordings=[self.RX, self.RX, self.RX],
            epoch_names=['A', 'B', 'C'])
        epoch_c = time_multi.get_epoch('C')
        self._check_recordings_equal(self.RX, epoch_c)
        self.assertEqual(4, len(epoch_c.get_channel_ids()))

        # Channel concatenation: channel ids 4-7, renamed to 0-3, should equal
        # self.RX2 and carry its group (2).
        chan_multi = se.MultiRecordingChannelExtractor(
            recordings=[self.RX, self.RX2, self.RX3], groups=[1, 2, 3])
        print(chan_multi.get_channel_groups())
        rx2_view = se.SubRecordingExtractor(chan_multi,
                                            channel_ids=[4, 5, 6, 7],
                                            renamed_channel_ids=[0, 1, 2, 3])
        self._check_recordings_equal(self.RX2, rx2_view)
        self.assertEqual([2, 2, 2, 2], rx2_view.get_channel_groups())
        self.assertEqual(12, len(chan_multi.get_channel_ids()))
示例#25
0
def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, max_channels_per_waveforms):
    """Return how many channels each extracted waveform should span.

    Parameters
    ----------
    recording : RecordingExtractor
        Recording the channels belong to; only used when ``grouping_property``
        is not None.
    grouping_property : str or None
        When None, all ``channel_ids`` are considered together. Otherwise the
        limit is the size of the largest channel group, where groups come from
        ``get_channel_groups()`` of the sub-recording.
    channel_ids : sequence
        Channel ids under consideration.
    max_channels_per_waveforms : int or None
        Optional cap; None means no cap.

    Returns
    -------
    int-like
        ``min(limit, max_channels_per_waveforms)`` (or just ``limit`` when the
        cap is None), where ``limit`` is ``len(channel_ids)`` without grouping
        or the largest group size with grouping (a NumPy integer in that case).
    """
    if grouping_property is None:
        limit = len(channel_ids)
    else:
        # NOTE(review): only the None-ness of grouping_property is checked;
        # groups always come from get_channel_groups() -- confirm intended.
        rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
        rec_groups = np.array(rec.get_channel_groups())
        # Only the per-group counts matter; the group labels are unused.
        _, counts = np.unique(rec_groups, return_counts=True)
        limit = np.max(counts)
    if max_channels_per_waveforms is None:
        return limit
    # min() returns `limit` on ties, matching the original `>=` comparison.
    return min(limit, max_channels_per_waveforms)
    def test_dump_load_multi_sub_extractor(self):
        """Multi- and sub-extractors built on MDA-backed extractors survive dumping.

        Writes ``self.RX`` / ``self.SX`` to MDA files under ``self.test_dir``,
        reloads them, wraps them in channel-concatenating, time-concatenating,
        sub-recording, sub-sorting and multi-sorting extractors, and runs
        ``check_dumping`` on each wrapper in that order.
        """
        mda_dir = self.test_dir + '/mda'
        firings_path = mda_dir + '/firings_true.mda'
        se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
        se.MdaSortingExtractor.write_sorting(self.SX, firings_path)
        rec_mda = se.MdaRecordingExtractor(mda_dir)
        sort_mda = se.MdaSortingExtractor(firings_path)

        # Each wrapper is built lazily so construction and dumping stay
        # interleaved one extractor at a time.
        builders = (
            lambda: se.MultiRecordingChannelExtractor(recordings=[rec_mda] * 3),
            lambda: se.MultiRecordingTimeExtractor(recordings=[rec_mda] * 3),
            lambda: se.SubRecordingExtractor(rec_mda, channel_ids=[0, 1]),
            lambda: se.SubSortingExtractor(sort_mda, unit_ids=[1, 2]),
            lambda: se.MultiSortingExtractor(sortings=[sort_mda] * 3),
        )
        for build in builders:
            check_dumping(build())
示例#27
0
 def run(self):
     """Sort ``self.recording_dir`` with MountainSort4 and write ``self.firings_out``.

     The recording is restricted to ``self.channels`` when that list is
     non-empty. The worker count is read from the ``NUM_WORKERS`` environment
     variable and converted to int when it is set to a non-empty value.
     """
     recording = si.MdaRecordingExtractor(self.recording_dir)
     if len(self.channels) > 0:
         recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=self.channels)

     # An absent or empty NUM_WORKERS is passed through unchanged (None / '').
     env_workers = os.environ.get('NUM_WORKERS', None)
     num_workers = int(env_workers) if env_workers else env_workers

     ms4_params = dict(
         detect_sign=self.detect_sign,
         adjacency_radius=self.adjacency_radius,
         freq_min=self.freq_min,
         freq_max=self.freq_max,
         whiten=self.whiten,
         clip_size=self.clip_size,
         detect_threshold=self.detect_threshold,
         detect_interval=self.detect_interval,
         noise_overlap_threshold=self.noise_overlap_threshold,
     )
     sorting = sf.sorters.mountainsort4(recording=recording, num_workers=num_workers, **ms4_params)
     si.MdaSortingExtractor.writeSorting(sorting=sorting, save_path=self.firings_out)
示例#28
0
    def run(self):
        """Run IronClust on ``self.recording_dir`` and write ``self.firings_out``.

        A scratch directory ``ironclust-tmp-<10 random uppercase letters>`` is
        created under ``$TEMPDIR`` (default ``/tmp``). It is removed afterwards,
        whether sorting succeeds or fails, unless ``self._keep_temp_files`` is
        truthy.

        Raises
        ------
        Exception
            If the ``IRONCLUST_PATH`` environment variable is not set.
        """
        ironclust_path = os.environ.get('IRONCLUST_PATH', None)
        if not ironclust_path:
            raise Exception('Environment variable not set: IRONCLUST_PATH')

        code = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-tmp-' + code

        try:
            recording = se.MdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = ironclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                prm_template_name=self.prm_template_name,
                ironclust_path=ironclust_path,
                params=params,
            )
            se.MdaSortingExtractor.writeSorting(sorting=sorting,
                                                save_path=self.firings_out)
        finally:
            # Single cleanup path, replacing the bare `except: ...; raise` and
            # the duplicated success-path rmtree. The existence check keeps
            # cleanup from raising (and masking the original error) when the
            # directory was never created or has already been removed.
            if not getattr(self, '_keep_temp_files', False) and os.path.exists(tmpdir):
                shutil.rmtree(tmpdir)
    def subset_recording(self, stub_test: bool = False):
        """
        Return a SubRecordingExtractor view of ``self.recording_extractor``.

        Parameters
        ----------
        stub_test : bool, optional (default False)
            If True, keep only the first 100 frames, or the whole recording if
            it has fewer frames.

        Returns
        -------
        SubRecordingExtractor
            Additionally restricted to ``self.subset_channels`` when that is
            not None.
        """
        sub_kwargs = {}

        if stub_test:
            total_frames = self.recording_extractor.get_num_frames()
            sub_kwargs['end_frame'] = min(100, total_frames)

        if self.subset_channels is not None:
            sub_kwargs['channel_ids'] = self.subset_channels

        return se.SubRecordingExtractor(self.recording_extractor, **sub_kwargs)
示例#30
0
def find_unit_peak_channels(recording, sorting, unit_ids):
    """Estimate each unit's peak channel from a short prefix of the recording.

    Each unit's spike train is truncated to its first 20 events, and the
    recording is cut to the frames those events need. For each id in
    ``unit_ids`` the median waveform across events is computed, and the
    channel with the largest peak-to-peak amplitude is reported.

    Returns
    -------
    dict
        ``{unit_id: int(channel_id)}`` for every id in ``unit_ids``.
    """
    snippet_len = (10, 10)
    # Use the first part of the recording to estimate the peak channels
    sorting_shortened = SubsampledSortingExtractor(parent_sorting=sorting, max_events_per_unit=20, method='truncate')
    max_time = 0
    for unit_id in sorting_shortened.get_unit_ids():
        st = sorting_shortened.get_unit_spike_train(unit_id=unit_id)
        if len(st) > 0:
            max_time = max(max_time, np.max(st))
    # Previously end_frame was max_time + 1, which cut off the frames after the
    # latest spike. Extend by snippet_len[1] (presumably the number of frames
    # after the spike -- TODO confirm against get_unit_waveforms), clipped to
    # the length of the full recording.
    end_frame = min(int(max_time) + 1 + snippet_len[1], recording.get_num_frames())
    recording_shortened = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=end_frame)
    unit_waveforms = get_unit_waveforms(
        recording=recording_shortened,
        sorting=sorting_shortened,
        unit_ids=unit_ids,
        channel_ids_by_unit=None,
        snippet_len=snippet_len
    )
    channel_ids = recording.get_channel_ids()
    peak_channels = {}
    for ii, unit_id in enumerate(unit_ids):
        # Rows are events; the median over them gives (channels x snippet samples).
        # NOTE(review): a unit with no events yields an all-NaN median here, so
        # argmax silently falls back to index 0 -- confirm that is acceptable.
        average_waveform = np.median(unit_waveforms[ii], axis=0)
        peak_to_peak = np.max(average_waveform, axis=1) - np.min(average_waveform, axis=1)
        peak_channel_index = int(np.argmax(peak_to_peak))
        peak_channels[unit_id] = int(channel_ids[peak_channel_index])
    return peak_channels