Exemplo n.º 1
0
    def run(self):
        """Sort self.recording_dir with Klusta and write firings to self.firings_out."""
        import spikesorters as sorters

        print('Klusta......')
        recording = SFMdaRecordingExtractor(self.recording_dir)

        # Random suffix keeps concurrent runs from sharing a temp folder.
        suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/klusta-' + suffix

        sorter = sorters.KlustaSorter(recording=recording,
                                      output_folder=tmpdir,
                                      debug=True,
                                      delete_output_folder=True)

        # All sorter parameters come straight off this processor instance.
        params = dict(
            adjacency_radius=self.adjacency_radius,
            detect_sign=self.detect_sign,
            threshold_strong_std_factor=self.threshold_strong_std_factor,
            threshold_weak_std_factor=self.threshold_weak_std_factor,
            n_features_per_channel=self.n_features_per_channel,
            num_starting_clusters=self.num_starting_clusters,
            extract_s_before=self.extract_s_before,
            extract_s_after=self.extract_s_after,
        )
        sorter.set_params(**params)

        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))

        SFMdaSortingExtractor.write_sorting(sorting=sorter.get_result(),
                                            save_path=self.firings_out)
Exemplo n.º 2
0
    def run(self):
        """Run KiloSort2 via kilosort2_helper and write the sorting to self.firings_out."""
        # Hard-coded here: temporary files are always kept for this sorter.
        _keep_temp_files = True

        suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + suffix

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Optionally restrict to a channel subset.
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording,
                                                     channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort2_helper(
                recording=recording,
                tmpdir=tmpdir,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                detect_threshold=self.detect_threshold,
                merge_thresh=self.merge_thresh,
                freq_min=self.freq_min,
                freq_max=self.freq_max,
                pc_per_chan=self.pc_per_chan,
                minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, remove the temp dir (unless keeping it), then propagate.
            if os.path.exists(tmpdir) and not _keep_temp_files:
                print('removing tmpdir1')
                shutil.rmtree(tmpdir)
            raise
        if not _keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
Exemplo n.º 3
0
    def javascript_state_changed(self, prev_state, state):
        """React to JS-side state changes: lazily load the recording, then
        serve any requested timeseries segments not yet delivered."""
        if not self._recording:
            recordingPath = state.get('recordingPath', None)
            if not recordingPath:
                return
            self.set_python_state(dict(status_message='Loading recording'))
            mt.configDownloadFrom(state.get('download_from'))
            X = SFMdaRecordingExtractor(dataset_directory=recordingPath,
                                        download=True)
            # Publish basic recording geometry back to the JS side.
            self.set_python_state(dict(
                numChannels=X.get_num_channels(),
                numTimepoints=X.get_num_frames(),
                samplerate=X.get_sampling_frequency(),
                status_message='Loaded recording.'))
            self._recording = X
        else:
            X = self._recording

        # Serve each requested segment that has not already been published.
        for key, req in state.get('segmentsRequested', {}).items():
            if not self.get_python_state(key, None):
                self.set_python_state(
                    dict(status_message='Loading segment {}'.format(key)))
                segment = self._load_data(req['ds'], req['ss'])
                encoded = _mda32_to_base64(segment)
                self.set_python_state(
                    {key: dict(data=encoded, ds=req['ds'], ss=req['ss'])})
                self.set_python_state(
                    dict(status_message='Loaded segment {}'.format(key)))
Exemplo n.º 4
0
    def run(self):
        """Run JRCLUST via jrclust_helper and write the sorting to self.firings_out."""
        tmpdir = _get_tmpdir('jrclust')

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            # Optionally restrict to a channel subset.
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording,
                                                     channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Collect every declared parameter off of self by name.
            all_params = {p.name: getattr(self, p.name) for p in self.PARAMETERS}

            sorting = jrclust_helper(recording=recording,
                                     tmpdir=tmpdir,
                                     params=params,
                                     **all_params)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, remove the temp dir unless explicitly kept, then re-raise.
            if os.path.exists(tmpdir) and not getattr(self, '_keep_temp_files', False):
                shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Exemplo n.º 5
0
    def run(self):
        """Run YASS via yass_helper and write the sorting to self.firings_out.

        NOTE(review): temp-dir removal is currently commented out on both the
        error and success paths; only messages are printed.
        """
        suffix = ''.join(random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + suffix

        # num_workers = os.environ.get('NUM_WORKERS', 1)
        # print('num_workers: {}'.format(num_workers))
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Optionally restrict to a channel subset.
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(parent_recording=recording,
                                                     channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting, _ = yass_helper(
                recording=recording,
                output_folder=tmpdir,
                probe_file=None,
                file_name=None,
                detect_sign=self.detect_sign,
                adjacency_radius=self.adjacency_radius,
                template_width_ms=self.template_width_ms,
                filter=self.filter)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
            # shutil.copyfile(yaml_file, self.paramfile_out)
        except:
            if os.path.exists(tmpdir):
                # shutil.rmtree(tmpdir)
                print('not deleted tmpdir1')
            raise
        if not getattr(self, '_keep_temp_files', False):
            # shutil.rmtree(tmpdir)
            print('not deleted tmpdir2')
Exemplo n.º 6
0
 def recordingExtractor(self, download: bool = False):
     """Return a recording extractor for this object's directory,
     restricted to the configured channel subset when one is present."""
     extractor = SFMdaRecordingExtractor(dataset_directory=self.directory(),
                                         download=download)
     # A present, non-empty 'channels' entry restricts the extractor.
     if 'channels' in self._obj and self._obj['channels']:
         extractor = si.SubRecordingExtractor(parent_recording=extractor,
                                              channel_ids=self._obj['channels'])
     return extractor
Exemplo n.º 7
0
def real(name='franklab_tetrode', download=True):
    """Load a named real (non-synthetic) dataset.

    Returns a (recording_extractor, None) tuple; the second slot is None
    because real datasets carry no ground-truth sorting.

    Raises Exception for unrecognized dataset names.
    """
    if name != 'franklab_tetrode':
        raise Exception('Unrecognized name for real dataset: ' + name)
    dsdir = 'kbucket://b5ecdf1474c5/datasets/neuron_paper/franklab_tetrode'
    recording = SFMdaRecordingExtractor(dataset_directory=dsdir,
                                        download=download)
    return (recording, None)
Exemplo n.º 8
0
    def run(self):
        """Run IronClust on self.recording_dir and write firings to self.firings_out.

        Installs IronClust if necessary, configures the sorter from this
        processor's parameters, runs it in a uniquely-named temp folder, and
        prints the '#SF-SORTER-RUNTIME#' marker with the elapsed seconds.
        """
        timer = time.time()  # NOTE(review): overwritten by sorter.run() below; initial value unused
        import spikesorters as sorters
        print('IronClust......')

        try:
            ironclust_path = IronClust.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing ironclust.')
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)


        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random suffix so concurrent runs get distinct temp folders.
        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder = False # cleanup is handled by _keep_temp_files one step above
        )

        # Forward every IronClust tuning parameter from this processor.
        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Exemplo n.º 9
0
 def run(self):
     """Bandpass-filter the recording and write the traces as 32-bit mda."""
     recording = SFMdaRecordingExtractor(
         dataset_directory=self.recording_directory, download=True)
     filtered = bandpass_filter(recording=recording,
                                freq_min=300,
                                freq_max=6000,
                                freq_wid=1000)
     ok = mdaio.writemda32(filtered.get_traces(), self.timeseries_out)
     if not ok:
         raise Exception('Unable to write output file.')
Exemplo n.º 10
0
 def createSession(self):
     """Build a full-browser timeseries widget over the first 10000 frames."""
     rec = SFMdaRecordingExtractor(
         dataset_directory=self._recording_directory, download=False)
     rec = se.SubRecordingExtractor(parent_recording=rec,
                                    start_frame=0, end_frame=10000)
     # Wrap the extracted traces in an in-memory extractor.
     rec = se.NumpyRecordingExtractor(
         timeseries=rec.get_traces(),
         samplerate=rec.get_sampling_frequency())
     widget = SFW.TimeseriesWidget(recording=rec)
     _make_full_browser(widget)
     return widget
Exemplo n.º 11
0
 def run(self):
     """Compute the median waveform of one unit and dump it to self.json_out."""
     recording = SFMdaRecordingExtractor(
         dataset_directory=self.recording_directory, download=True)
     sorting = SFMdaSortingExtractor(firings_file=self.firings)
     spikes = _get_random_spike_waveforms(recording=recording,
                                          sorting=sorting,
                                          unit=self.unit_id)
     # Median over the spike axis (axis=2) gives the unit template.
     template = np.median(spikes, axis=2)
     output = dict(channel_ids=recording.get_channel_ids(),
                   average_waveform=template.tolist())
     with open(self.json_out, 'w') as f:
         json.dump(output, f)
Exemplo n.º 12
0
def yass_example(download=True, set_id=1):
    """Load one of the six VisAPy MEA ground-truth example datasets.

    Parameters
    ----------
    download : bool
        Whether to download the data if not already cached locally.
    set_id : int
        Which example set to load; must be in 1..6.

    Returns
    -------
    tuple
        (recording_extractor, ground_truth_sorting_extractor).

    Raises
    ------
    Exception
        If set_id is outside 1..6.
    """
    if set_id not in range(1, 7):
        # Fixed typo in the error message: 'betewen' -> 'between'.
        raise Exception(
            'Invalid ID for yass_example {} is not between 1..6'.format(
                set_id))
    dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(
        set_id)
    IX = SFMdaRecordingExtractor(dataset_directory=dsdir,
                                 download=download)
    path1 = os.path.join(dsdir, 'firings_true.mda')
    print(path1)
    OX = SFMdaSortingExtractor(path1)
    return (IX, OX)
Exemplo n.º 13
0
    def run(self):
        """Run HerdingSpikes2 on self.recording_dir and write firings to self.firings_out.

        Uses the sorter's builtin bandpass filtering and normalisation; the
        temp folder is removed afterwards unless self._keep_temp_files is set.
        """
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
        from spikeforest_common import autoScaleRecordingToNoiseLevel
        import spiketoolkit as st

        # Optional clustering parallelism, taken from the environment.
        clustering_n_jobs = os.environ.get('NUM_WORKERS', None)
        if clustering_n_jobs is not None:
            clustering_n_jobs = int(clustering_n_jobs)

        # Random suffix so concurrent runs get distinct temp folders.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/hs2-tmp-' + code

        try:
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            recording = SFMdaRecordingExtractor(self.recording_dir)
            # print('Auto scaling via normalize_by_quantile...')
            # recording = st.preprocessing.normalize_by_quantile(recording=recording, scale=200.0)
            # recording = autoScaleRecordingToNoiseLevel(recording, noise_level=32)

            print('Running HerdingspikesSorter...')
            os.environ['HS2_PROBE_PATH'] = tmpdir
            st_sorter = st.sorters.HerdingspikesSorter(recording=recording,
                                                       output_folder=tmpdir +
                                                       '/hs2_sorting_output')
            print('Using builtin bandpass and normalisation')
            # Start from the sorter defaults, then enable filtering/prescaling.
            hs2_par = st_sorter.default_params()
            hs2_par['filter'] = True
            hs2_par['pre_scale_value'] = 20
            hs2_par['pre_scale'] = True
            st_sorter.set_params(**hs2_par)
            if clustering_n_jobs is not None:
                st_sorter.set_params(clustering_n_jobs=clustering_n_jobs)
            timer = st_sorter.run()
            print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
            sorting = st_sorter.get_result()

            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, remove temp data (unless kept for debugging) and re-raise.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise

        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Exemplo n.º 14
0
 def run(self):
     """Create a spike-spray comparison object and dump it to self.json_out."""
     recording = SFMdaRecordingExtractor(
         dataset_directory=self.recording_directory,
         download=True,
         raw_fname=self.filtered_timeseries)
     true_sorting = SFMdaSortingExtractor(firings_file=self.firings_true)
     sorted_sorting = SFMdaSortingExtractor(firings_file=self.firings_sorted)
     spikesprays = create_spikesprays(
         rx=recording,
         sx_true=true_sorting,
         sx_sorted=sorted_sorting,
         neighborhood_size=self.neighborhood_size,
         num_spikes=self.num_spikes,
         unit_id_true=self.unit_id_true,
         unit_id_sorted=self.unit_id_sorted)
     with open(self.json_out, 'w') as f:
         json.dump(spikesprays, f)
Exemplo n.º 15
0
    def run(self):
        """Run MountainSort4 (with optional bandpass/whiten preprocessing)
        and write the resulting firings to self.firings_out."""
        # Test hook: deliberately fail after a short delay when requested.
        if self.throw_error:
            import time
            print(
                'Intentionally throwing an error in 3 seconds (MountainSort4TestError)...'
            )
            sys.stdout.flush()
            time.sleep(3)
            raise Exception('Intentional error.')
        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Optional bandpass filter
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(recording=recording,
                                        freq_min=self.freq_min,
                                        freq_max=self.freq_max)

        # Optional whitening
        if self.whiten:
            recording = whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers)

        # Curation step intentionally disabled:
        # if self.noise_overlap_threshold is not None:
        #    sorting=ml_ms4alg.mountainsort4_curation(
        #      recording=recording, sorting=sorting,
        #      noise_overlap_threshold=self.noise_overlap_threshold)

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Exemplo n.º 16
0
    def run(self):
        """Run Tridesclous via the spikesorters wrapper and write firings out.

        Sorting happens in a uniquely-named temp folder which is removed
        afterwards unless self._keep_temp_files is set.
        """
        print('Running Tridesclous...')
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
        # from spikeforest_common import autoScaleRecordingToNoiseLevel
        # import spiketoolkit as st
        import spikesorters

        # Random suffix so concurrent runs get distinct temp folders.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/tdc-tmp-' + code

        try:
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            print('Loading recording...')
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # print('Auto scaling via normalize_by_quantile...')
            # recording = st.preprocessing.normalize_by_quantile(recording=recording, scale=200.0)
            # recording = autoScaleRecordingToNoiseLevel(recording, noise_level=32)

            print('Running TridesclousSorter...')
            os.environ['HS2_PROBE_PATH'] = tmpdir
            st_sorter = spikesorters.TridesclousSorter(recording=recording,
                                                       output_folder=tmpdir +
                                                       '/tdc_sorting_output',
                                                       verbose=True)
            # setattr(st_sorter, 'debug', True)
            # (Removed a stray no-op expression statement `st_sorter` here.)
            timer = st_sorter.run()
            print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
            sorting = st_sorter.get_result()

            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, remove temp data (unless kept for debugging) and re-raise.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise

        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Exemplo n.º 17
0
    def run(self):
        """Run MountainSort4 via the spikesorters wrapper and write firings out."""
        # from spikeinterface/spikesorters
        import spikesorters as sorters

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Random suffix so concurrent runs get distinct temp folders.
        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/mountainsort4-' + code

        sorter = sorters.Mountainsort4Sorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
            curation=False,
            whiten=True,
            filter=True,
            freq_min=self.freq_min,
            freq_max=self.freq_max
        )

        # Resolves the former TODO: run() returns the elapsed time, which we
        # report with the same marker the sibling sorter wrappers print.
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Exemplo n.º 18
0
    def run(self):
        """Preprocess (optional bandpass + whiten) and run MountainSort4,
        writing the resulting firings to self.firings_out."""
        from .bandpass_filter import bandpass_filter
        from .whiten import whiten

        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Optional bandpass filter
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(recording=recording,
                                        freq_min=self.freq_min,
                                        freq_max=self.freq_max)

        # Optional whitening
        if self.whiten:
            recording = whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(recording=recording,
                                          detect_sign=self.detect_sign,
                                          adjacency_radius=self.adjacency_radius,
                                          clip_size=self.clip_size,
                                          detect_threshold=self.detect_threshold,
                                          detect_interval=self.detect_interval,
                                          num_workers=num_workers)

        # Curation step intentionally disabled:
        # if self.noise_overlap_threshold is not None:
        #    sorting=ml_ms4alg.mountainsort4_curation(
        #      recording=recording, sorting=sorting,
        #      noise_overlap_threshold=self.noise_overlap_threshold)

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Exemplo n.º 19
0
    def run(self):
        """Install (if needed) and run KiloSort2, writing firings to self.firings_out."""
        import spikesorters as sorters
        print('Kilosort2......')

        try:
            kilosort2_path = KiloSort2.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing kilosort.')
        sorters.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-' + suffix

        sorter = sorters.Kilosort2Sorter(recording=recording,
                                         output_folder=tmpdir,
                                         debug=True,
                                         delete_output_folder=True)

        sorter.set_params(detect_threshold=self.detect_threshold,
                          car=self.car,
                          minFR=self.minFR,
                          electrode_dimensions=None,
                          freq_min=self.freq_min,
                          sigmaMask=self.adjacency_radius,
                          nPCs=self.pc_per_chan)

        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))

        SFMdaSortingExtractor.write_sorting(sorting=sorter.get_result(),
                                            save_path=self.firings_out)
Exemplo n.º 20
0
    def run(self):
        """Run SpyKING CIRCUS via the spikesorters wrapper and write firings out."""
        import spikesorters as sorters
        print('SpyKING CIRCUS......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        suffix = ''.join(
            random.choice(string.ascii_uppercase) for _ in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-' + suffix

        # Parallelism from the environment (defaults to a single worker).
        num_workers = int(os.environ.get('NUM_WORKERS', '1'))

        sorter = sorters.SpykingcircusSorter(recording=recording,
                                             output_folder=tmpdir,
                                             verbose=True,
                                             delete_output_folder=True)

        params = dict(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            auto_merge=0.5,
            num_workers=num_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts,
        )
        sorter.set_params(**params)

        elapsed = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))

        SFMdaSortingExtractor.write_sorting(sorting=sorter.get_result(),
                                            save_path=self.firings_out)
Exemplo n.º 21
0
    def run(self):
        """Install (if needed) and run KiloSort, writing firings to self.firings_out."""
        import spikesorters as sorters
        print('KiloSort......')

        try:
            kilosort_path = KiloSort.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing kilosort.')
        sorters.KilosortSorter.set_kilosort_path(kilosort_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-' + code

        sorter = sorters.KilosortSorter(recording=recording,
                                        output_folder=tmpdir,
                                        debug=True,
                                        delete_output_folder=True)

        sorter.set_params(detect_threshold=self.detect_threshold,
                          freq_min=self.freq_min,
                          freq_max=self.freq_max,
                          car=True,
                          useGPU=True,
                          electrode_dimensions=None)

        # Resolves the former TODO: run() returns the elapsed time, which we
        # report with the same marker the sibling sorter wrappers print.
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Exemplo n.º 22
0
    def initialize(self):
        """Initialize the recording context (idempotent).

        Loads (and optionally downloads) the recording described by
        self._recording_object, applies a bandpass filter, and initializes
        the true-sorting and intra-recording child contexts if present.
        """
        if self._initialized:
            return
        self._initialized = True

        print('******** FORESTVIEW: Initializing recording context')
        # (Removed a no-op self-assignment of self._recording_object.)
        if self._download:
            print(
                '******** FORESTVIEW: Downloading recording file if needed...')
        recdir = self._recording_object['directory']
        raw_fname = self._recording_object.get('raw_fname', 'raw.mda')
        params_fname = self._recording_object.get('params_fname',
                                                  'params.json')
        self._rx = SFMdaRecordingExtractor(dataset_directory=recdir,
                                           download=self._download,
                                           raw_fname=raw_fname,
                                           params_fname=params_fname)
        self._rx = bandpass_filter(self._rx)

        if self._true_sorting_context:
            self._true_sorting_context.initialize()

        if self._intra_recording_context:
            self._intra_recording_context.initialize()

        print('******** FORESTVIEW: Done initializing recording context')
Exemplo n.º 23
0
    def run(self):
        """Run Tridesclous directly through the tdc API.

        Exports the recording to a probe file and (if needed) a raw binary
        file, configures a tdc DataIO in a temp folder, runs tdc_helper, and
        writes the resulting sorting to self.firings_out. The temp folder is
        removed afterwards unless self._keep_temp_files is set.
        """
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))
        recording = SFMdaRecordingExtractor(self.recording_dir)

        # Full tridesclous parameter tree; `self.*` values come from this
        # processor's declared parameters, the rest are fixed defaults.
        params = {
            'fullchain_kargs': {
                'duration': 300.,
                'preprocessor': {
                    'highpass_freq': self.freq_min,
                    'lowpass_freq': self.freq_max,
                    'smooth_size': 0,
                    'chunksize': 1024,
                    'lostfront_chunksize': 128,
                    'signalpreprocessor_engine': 'numpy',
                    'common_ref_removal': self.common_ref_removal,
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': self.detection_threshold,
                    'peak_span': self.peak_span,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': self.waveforms_n_left,
                    'n_right': self.waveforms_n_right,
                    'mode': 'rand',
                    'nb_max': 20000,
                    'align_waveform': self.align_waveform,
                },
                'clean_waveforms': {
                    'alien_value_threshold': self.alien_value_threshold,
                },
            },
            'feat_method': 'peak_max',
            'feat_kargs': {},
            'clust_method': 'sawchaincut',
            'clust_kargs': {
                'kde_bandwith': 1.
            },
        }

        # save prb file:
        probe_file = tmpdir / 'probe.prb'
        se.save_probe_file(recording, probe_file, format='spyking_circus')

        # source file
        if isinstance(recording,
                      se.BinDatRecordingExtractor) and recording._frame_first:
            # no need to copy: reuse the extractor's existing binary file
            raw_filename = recording._datfile
            dtype = recording._timeseries.dtype.str
            nb_chan = len(recording._channels)
            offset = recording._timeseries.offset
        else:
            # save binary file (chunk by chunk) into a new file
            raw_filename = tmpdir / 'raw_signals.raw'
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            se.write_binary_dat_format(recording,
                                       raw_filename,
                                       time_axis=0,
                                       dtype='float32',
                                       chunksize=chunksize)
            dtype = 'float32'
            offset = 0

        # initialize source and probe file
        tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
        nb_chan = recording.get_num_channels()

        tdc_dataio.set_data_source(
            type='RawData',
            filenames=[str(raw_filename)],
            dtype=dtype,
            sample_rate=recording.get_sampling_frequency(),
            total_channel=nb_chan,
            offset=offset)
        tdc_dataio.set_probe_file(str(probe_file))

        try:
            sorting = tdc_helper(tmpdir=tmpdir,
                                 params=params,
                                 recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, remove temp data (unless kept for debugging) and re-raise.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Exemplo n.º 24
0
#!/usr/bin/env python
# Tutorial script: load a SpikeForest example recording with its ground
# truth and prepare an empty output directory for a sorting run.

from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth
# You can also substitute any of the other available recordings
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')

# import a spike sorter from the spikesorters module of spikeforest
from spikeforestsorters import MountainSort4
import os
import shutil

# In place of MountainSort4 you could use any of the following:
#
# MountainSort4, SpykingCircus, KiloSort, KiloSort2, YASS
# IronClust, HerdingSpikes2, JRClust, Tridesclous, Klusta
# although the Matlab sorters require further setup.

# clear and create an empty output directory (keep things tidy)
if os.path.exists('test_outputs'):
    shutil.rmtree('test_outputs')
os.makedirs('test_outputs', exist_ok=True)