Beispiel #1
0
    def run(self):
        """Run the YASS spike sorter and write the sorted firings.

        Reads the recording from ``self.recording_dir``, optionally
        restricts it to ``self.channels``, sorts via ``yass_helper`` in a
        temporary directory, and writes the result to ``self.firings_out``.
        """
        # Random suffix so concurrent runs do not collide in TEMPDIR.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/yass-tmp-' + code

        # num_workers = os.environ.get('NUM_WORKERS', 1)
        # print('num_workers: {}'.format(num_workers))
        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Restrict to a channel subset when one was requested.
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting, _ = yass_helper(recording=recording,
                                     output_folder=tmpdir,
                                     probe_file=None,
                                     file_name=None,
                                     detect_sign=self.detect_sign,
                                     adjacency_radius=self.adjacency_radius,
                                     template_width_ms=self.template_width_ms,
                                     filter=self.filter)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
            # shutil.copyfile(yaml_file, self.paramfile_out)
        except:
            # NOTE: cleanup is currently disabled (rmtree commented out);
            # the temp dir is left behind and the original error re-raised.
            if os.path.exists(tmpdir):
                # shutil.rmtree(tmpdir)
                print('not deleted tmpdir1')
            raise
        if not getattr(self, '_keep_temp_files', False):
            # shutil.rmtree(tmpdir)
            print('not deleted tmpdir2')
def _generate_toy_recordings():
    """Generate (or reuse) a set of toy recordings with ground truth.

    Creates ``toy_recordings/example_K{K}`` for K in {5, 10, 15, 20} and
    returns a list of dicts describing each recording (name, study,
    absolute directory, description). Existing recordings are reused
    unless ``replace_recordings`` is set.
    """
    if not os.path.exists('toy_recordings'):
        os.mkdir('toy_recordings')

    # Flip to True to force regeneration of existing recordings.
    replace_recordings = False

    ret = []
    for K in [5, 10, 15, 20]:
        recpath = 'toy_recordings/example_K{}'.format(K)
        if os.path.exists(recpath) and replace_recordings:
            # Remove the stale copy so it is regenerated below.
            shutil.rmtree(recpath)
        if not os.path.exists(recpath):
            # Fix: previously this message was printed at deletion time,
            # and "already exists" was printed even for missing recordings.
            print('Generating toy recording: {}'.format(recpath))
            rx, sx_true = example_datasets.toy_example1(duration=60,
                                                        num_channels=4,
                                                        samplerate=30000,
                                                        K=K)
            SFMdaRecordingExtractor.write_recording(recording=rx,
                                                    save_path=recpath)
            SFMdaSortingExtractor.write_sorting(sorting=sx_true,
                                                save_path=recpath +
                                                '/firings_true.mda')
        else:
            print('Recording already exists: {}'.format(recpath))
        ret.append(
            dict(name='example_K{}'.format(K),
                 study='toy_study',
                 directory=os.path.abspath(recpath),
                 description='A toy recording with K={} units'.format(K)))

    return ret
 def compute_score(self, sorting_extractor):
     """Return the negated quality score of a sorting vs. ground truth.

     For metrics 'accuracy', 'precision', 'recall' and 'f1' the score is
     taken from a spikecomparison ground-truth comparison; for
     'spikeforest' the SpikeForest comparison pipeline is run in
     ``test_outputs_spikeforest`` and the mean per-unit accuracy is used.
     The score is negated so a minimizer can optimize it.

     Raises:
         ValueError: if ``self.metric`` is not one of the known metrics.
     """
     if self.metric != 'spikeforest':
         comparison = sc.compare_sorter_to_ground_truth(self.gt_se,
                                                        sorting_extractor, exhaustive_gt=True)
         d_results = comparison.get_performance(method='pooled_with_average', output='dict')
         print('results')
         print(d_results)
         if self.metric == 'accuracy':
             score = d_results['accuracy']
         elif self.metric == 'precision':
             score = d_results['precision']
         elif self.metric == 'recall':
             score = d_results['recall']
         elif self.metric == 'f1':
             print('comparison:')
             print(d_results)
             # F1 = harmonic mean of precision and recall; guard the
             # zero-denominator case.
             if (d_results['precision'] + d_results['recall']) > 0:
                 score = 2 * d_results['precision'] * d_results['recall'] / (d_results['precision'] + d_results['recall'])
             else:
                 score = 0
         else:
             # Fix: an unknown metric previously fell through every
             # branch and crashed with UnboundLocalError on 'score'.
             raise ValueError('Unknown metric: {}'.format(self.metric))
         del comparison
     else:
         tmp_dir = 'test_outputs_spikeforest'
         SFMdaSortingExtractor.write_sorting(sorting=sorting_extractor, save_path=os.path.join(tmp_dir, 'firings.mda'))
         print('Compare with ground truth...')
         sa.GenSortingComparisonTable.execute(firings=os.path.join(tmp_dir, 'firings.mda'),
                                              firings_true=os.path.join(tmp_dir, 'firings_true.mda'),
                                              units_true=self.true_units_above,  # use all units
                                              json_out=os.path.join(tmp_dir, 'comparison.json'),
                                              html_out=os.path.join(tmp_dir, 'comparison.html'),
                                              _container=None)
         comparison = mt.loadObject(path=os.path.join(tmp_dir, 'comparison.json'))
         score = np.mean([float(u['accuracy']) for u in comparison.values()])
     return -score
Beispiel #4
0
    def run(self):
        """Run the JRCLUST sorter and write the sorted firings.

        Loads the recording and dataset params, optionally restricts to
        ``self.channels``, sorts via ``jrclust_helper`` in a temp dir and
        writes the result to ``self.firings_out``. The temp dir is removed
        unless ``self._keep_temp_files`` is set.
        """
        tmpdir = _get_tmpdir('jrclust')

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            params = read_dataset_params(self.recording_dir)
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            # Collect the declared sorter parameters off this processor.
            all_params = dict()
            for param0 in self.PARAMETERS:
                all_params[param0.name] = getattr(self, param0.name)

            sorting = jrclust_helper(
                recording=recording,
                tmpdir=tmpdir,
                params=params,
                **all_params,
            )
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # Remove the temp dir on failure (unless keeping temp files),
            # then re-raise the original error.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Beispiel #5
0
def gen_synth_datasets(datasets, *, outdir, samplerate=32000):
    """Generate synthetic recordings with ground-truth sortings.

    For each dataset spec in ``datasets`` (a dict with keys such as
    'name', 'duration', 'n_exc', 'templates', 'noise_level', ...) this
    generates spike trains, synthesizes a recording, and writes both the
    recording and ``firings_true.mda`` under ``outdir/<name>``.

    Args:
        datasets: iterable of dataset-spec dicts (mutated: a missing
            'seed' key is filled in with 0).
        outdir: output directory; created if it does not exist.
        samplerate: sampling rate in Hz (default 32000).
    """
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        # Default seed for reproducibility when none was specified
        # (idiomatic replacement for the `'seed' not in ds.keys()` check).
        ds.setdefault('seed', 0)
        spiketrains = gen_spiketrains(
            duration=ds['duration'],
            n_exc=ds['n_exc'],
            n_inh=ds['n_inh'],
            f_exc=ds['f_exc'],
            f_inh=ds['f_inh'],
            min_rate=ds['min_rate'],
            st_exc=ds['st_exc'],
            st_inh=ds['st_inh'],
            seed=ds['seed']
        )
        OX = NeoSpikeTrainsOutputExtractor(
            spiketrains=spiketrains, samplerate=samplerate)
        X, geom = gen_recording(
            templates=ds['templates'],
            output_extractor=OX,
            noise_level=ds['noise_level'],
            samplerate=samplerate,
            duration=ds['duration']
        )
        IX = si.NumpyRecordingExtractor(
            timeseries=X, samplerate=samplerate, geom=geom)
        SFMdaRecordingExtractor.write_recording(
            IX, outdir + '/{}'.format(ds_name))
        SFMdaSortingExtractor.write_sorting(
            OX, outdir + '/{}/firings_true.mda'.format(ds_name))
    print('Done.')
Beispiel #6
0
    def run(self):
        """Compute per-unit autocorrelograms of a sorting and write JSON.

        Loads the firings from ``self.firings_path``, converts the
        millisecond window/bin parameters to sample counts using
        ``self.samplerate``, and writes a list of per-unit histograms
        (bin_counts, bin_edges) to ``self.json_out``.
        """
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor

        print('test1', self.firings_path, self.samplerate)

        sorting = SFMdaSortingExtractor(firings_file=self.firings_path)
        samplerate = self.samplerate
        max_samples = self.max_samples
        max_dt_msec = self.max_dt_msec
        bin_size_msec = self.bin_size_msec

        # Convert from milliseconds to timepoints (samples).
        max_dt_tp = max_dt_msec * samplerate / 1000
        bin_size_tp = bin_size_msec * samplerate / 1000

        autocorrelograms = []
        for unit_id in sorting.get_unit_ids():
            print('Unit::g {}'.format(unit_id))
            (bin_counts, bin_edges) = compute_autocorrelogram(sorting.get_unit_spike_train(unit_id), max_dt_tp=max_dt_tp, bin_size_tp=bin_size_tp, max_samples=max_samples)
            autocorrelograms.append(dict(
                unit_id=unit_id,
                bin_counts=bin_counts,
                bin_edges=bin_edges
            ))
        ret = dict(
            autocorrelograms=autocorrelograms
        )
        # serialize_np converts numpy values into JSON-serializable types.
        with open(self.json_out, 'w') as f:
            json.dump(serialize_np(ret), f)
Beispiel #7
0
    def run(self):
        """Run the Klusta sorter via spikesorters and write sorted firings."""
        import spikesorters as sorters

        print('Klusta......')
        recording = SFMdaRecordingExtractor(self.recording_dir)

        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/klusta-' + code

        sorter = sorters.KlustaSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )

        sorter.set_params(
            adjacency_radius=self.adjacency_radius,
            detect_sign=self.detect_sign,
            threshold_strong_std_factor=self.threshold_strong_std_factor,
            threshold_weak_std_factor=self.threshold_weak_std_factor,
            n_features_per_channel=self.n_features_per_channel,
            num_starting_clusters=self.num_starting_clusters,
            extract_s_before=self.extract_s_before,
            extract_s_after=self.extract_s_after
        )

        timer = sorter.run()
        # Emit the elapsed-time marker line.
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Beispiel #8
0
    def run(self):
        """Run KiloSort2 via ``kilosort2_helper`` and write sorted firings."""
        # NOTE: hard-coded here, so temp files are always kept and the
        # rmtree branches below never execute.
        _keep_temp_files = True

        # Random suffix so concurrent runs do not collide in TEMPDIR.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-tmp-' + code

        try:
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # Restrict to a channel subset when one was requested.
            if len(self.channels) > 0:
                recording = se.SubRecordingExtractor(
                    parent_recording=recording, channel_ids=self.channels)
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)
            sorting = kilosort2_helper(recording=recording,
                                       tmpdir=tmpdir,
                                       detect_sign=self.detect_sign,
                                       adjacency_radius=self.adjacency_radius,
                                       detect_threshold=self.detect_threshold,
                                       merge_thresh=self.merge_thresh,
                                       freq_min=self.freq_min,
                                       freq_max=self.freq_max,
                                       pc_per_chan=self.pc_per_chan,
                                       minFR=self.minFR)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # Optionally clean up the temp dir on failure, then re-raise.
            if os.path.exists(tmpdir):
                if not _keep_temp_files:
                    print('removing tmpdir1')
                    shutil.rmtree(tmpdir)
            raise
        if not _keep_temp_files:
            print('removing tmpdir2')
            shutil.rmtree(tmpdir)
Beispiel #9
0
    def run(self):
        """Run the IronClust sorter via spikesorters and write sorted firings.

        Installs IronClust if needed, sorts the recording in a random
        temp directory under TEMPDIR, prints the runtime marker line, and
        writes the result to ``self.firings_out``.

        Raises:
            Exception: if installing IronClust fails.
        """
        import spikesorters as sorters
        print('IronClust......')

        try:
            ironclust_path = IronClust.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing ironclust.')
        sorters.IronClustSorter.set_ironclust_path(ironclust_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/ironclust-' + code

        sorter = sorters.IronClustSorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            # Cleanup is handled by _keep_temp_files one step above.
            delete_output_folder=False
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            adjacency_radius_out=self.adjacency_radius_out,
            detect_threshold=self.detect_threshold,
            prm_template_name=self.prm_template_name,
            freq_min=self.freq_min,
            freq_max=self.freq_max,
            merge_thresh=self.merge_thresh,
            pc_per_chan=self.pc_per_chan,
            whiten=self.whiten,
            filter_type=self.filter_type,
            filter_detect_type=self.filter_detect_type,
            common_ref_type=self.common_ref_type,
            batch_sec_drift=self.batch_sec_drift,
            step_sec_drift=self.step_sec_drift,
            knn=self.knn,
            min_count=self.min_count,
            fGpu=self.fGpu,
            fft_thresh=self.fft_thresh,
            fft_thresh_low=self.fft_thresh_low,
            nSites_whiten=self.nSites_whiten,
            feature_type=self.feature_type,
            delta_cut=self.delta_cut,
            post_merge_mode=self.post_merge_mode,
            sort_mode=self.sort_mode
        )
        # Fix: the original also did `timer = time.time()` at the top,
        # which was dead code — immediately overwritten here.
        timer = sorter.run()
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Beispiel #10
0
    def run(self):
        """Run the HerdingSpikes2 sorter and write the sorted firings.

        Uses the sorter's builtin bandpass filter and pre-scaling; the
        clustering worker count can be set via the NUM_WORKERS environment
        variable. The temp dir is removed unless ``self._keep_temp_files``
        is set.
        """
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
        from spikeforest_common import autoScaleRecordingToNoiseLevel
        import spiketoolkit as st

        clustering_n_jobs = os.environ.get('NUM_WORKERS', None)
        if clustering_n_jobs is not None:
            clustering_n_jobs = int(clustering_n_jobs)

        # Random suffix so concurrent runs do not collide in TEMPDIR.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/hs2-tmp-' + code

        try:
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            recording = SFMdaRecordingExtractor(self.recording_dir)
            # print('Auto scaling via normalize_by_quantile...')
            # recording = st.preprocessing.normalize_by_quantile(recording=recording, scale=200.0)
            # recording = autoScaleRecordingToNoiseLevel(recording, noise_level=32)

            print('Running HerdingspikesSorter...')
            os.environ['HS2_PROBE_PATH'] = tmpdir
            st_sorter = st.sorters.HerdingspikesSorter(recording=recording,
                                                       output_folder=tmpdir +
                                                       '/hs2_sorting_output')
            print('Using builtin bandpass and normalisation')
            hs2_par = st_sorter.default_params()
            hs2_par['filter'] = True
            hs2_par['pre_scale_value'] = 20
            hs2_par['pre_scale'] = True
            st_sorter.set_params(**hs2_par)
            if clustering_n_jobs is not None:
                st_sorter.set_params(clustering_n_jobs=clustering_n_jobs)
            timer = st_sorter.run()
            print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
            sorting = st_sorter.get_result()

            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # Remove the temp dir on failure (unless keeping temp files),
            # then re-raise the original error.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise

        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Beispiel #11
0
 def run(self):
     """Build a spike-spray comparison of a true vs. sorted unit pair.

     Loads the recording (downloading if needed) plus the true and
     sorted firings, builds the spike-spray object via
     ``create_spikesprays``, and writes it to ``self.json_out``.
     """
     rx = SFMdaRecordingExtractor(
         dataset_directory=self.recording_directory,
         download=True,
         raw_fname=self.filtered_timeseries)
     sx_true = SFMdaSortingExtractor(firings_file=self.firings_true)
     sx = SFMdaSortingExtractor(firings_file=self.firings_sorted)
     ssobj = create_spikesprays(rx=rx,
                                sx_true=sx_true,
                                sx_sorted=sx,
                                neighborhood_size=self.neighborhood_size,
                                num_spikes=self.num_spikes,
                                unit_id_true=self.unit_id_true,
                                unit_id_sorted=self.unit_id_sorted)
     with open(self.json_out, 'w') as f:
         json.dump(ssobj, f)
    def run(self):
        """Run MountainSort4 (ml_ms4alg) and write the sorted firings.

        When ``self.throw_error`` is set, deliberately raises after a
        short delay (used to test pipeline error handling). Otherwise
        optionally bandpass-filters and whitens the recording, sorts it,
        and writes the result to ``self.firings_out``.
        """
        if self.throw_error:
            import time
            print(
                'Intentionally throwing an error in 3 seconds (MountainSort4TestError)...'
            )
            sys.stdout.flush()
            time.sleep(3)
            raise Exception('Intentional error.')
        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        # NUM_WORKERS env var overrides the worker count when set.
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Bandpass filter
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(recording=recording,
                                        freq_min=self.freq_min,
                                        freq_max=self.freq_max)

        # Whiten
        if self.whiten:
            recording = whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers)

        # Curate
        # if self.noise_overlap_threshold is not None:
        #    sorting=ml_ms4alg.mountainsort4_curation(
        #      recording=recording,
        #      sorting=sorting,
        #      noise_overlap_threshold=self.noise_overlap_threshold
        #    )

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Beispiel #13
0
def load_sorting_results_info(firings_path, *, recording_path, epoch_name, ntrode_name, curated=False):
    """Summarize a sorting result as a plain dict.

    Returns None when the firings file cannot be found; otherwise returns
    a dict with the epoch/ntrode names, paths, unit ids, and the total
    number of spike events across all units.
    """
    if not mt.findFile(firings_path):
        return None
    sorting = SFMdaSortingExtractor(firings_file=firings_path)
    # Total event count = sum of spike-train lengths over all units.
    total_num_events = sum(
        len(sorting.get_unit_spike_train(unit_id=unit_id))
        for unit_id in sorting.get_unit_ids())
    return dict(
        type='sorting_results',
        epoch_name=epoch_name,
        ntrode_name=ntrode_name,
        curated=curated,
        firings_path=firings_path,
        recording_path=recording_path,
        unit_ids=sorting.get_unit_ids(),
        num_events=total_num_events
    )
    def run(self):
        """Run the Tridesclous sorter and write the sorted firings.

        Sorting happens in a random temp directory under TEMPDIR; the
        directory is removed afterwards (and on failure) unless
        ``self._keep_temp_files`` is set.
        """
        print('Running Tridesclous...')
        from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
        # from spikeforest_common import autoScaleRecordingToNoiseLevel
        # import spiketoolkit as st
        import spikesorters

        # Random suffix so concurrent runs do not collide in TEMPDIR.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/tdc-tmp-' + code

        try:
            if not os.path.exists(tmpdir):
                os.mkdir(tmpdir)

            print('Loading recording...')
            recording = SFMdaRecordingExtractor(self.recording_dir)
            # print('Auto scaling via normalize_by_quantile...')
            # recording = st.preprocessing.normalize_by_quantile(recording=recording, scale=200.0)
            # recording = autoScaleRecordingToNoiseLevel(recording, noise_level=32)

            print('Running TridesclousSorter...')
            os.environ['HS2_PROBE_PATH'] = tmpdir
            st_sorter = spikesorters.TridesclousSorter(recording=recording,
                                                       output_folder=tmpdir +
                                                       '/tdc_sorting_output',
                                                       verbose=True)
            # setattr(st_sorter, 'debug', True)
            # Fix: removed a stray no-op expression statement `st_sorter`
            # that was here in the original.
            timer = st_sorter.run()
            print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
            sorting = st_sorter.get_result()

            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # Remove the temp dir on failure (unless keeping temp files),
            # then re-raise the original error.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise

        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Beispiel #15
0
    def run(self):
        """Run MountainSort4 via spikesorters and write sorted firings."""
        # from spikeinterface/spikesorters
        import spikesorters as sorters

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        # NUM_WORKERS env var overrides the worker count when set.
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(random.choice(string.ascii_uppercase)
                       for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/mountainsort4-' + code

        sorter = sorters.Mountainsort4Sorter(
            recording=recording,
            output_folder=tmpdir,
            debug=True,
            delete_output_folder=True
        )

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
            curation=False,
            whiten=True,
            filter=True,
            freq_min=self.freq_min,
            freq_max=self.freq_max
        )

        # TODO: get elapsed time from the return of this run
        sorter.run()

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Beispiel #16
0
    def run(self):
        """Run MountainSort4 (ml_ms4alg) directly and write sorted firings.

        Optionally bandpass-filters and whitens the recording before
        sorting; the result is written to ``self.firings_out``.
        """
        from .bandpass_filter import bandpass_filter
        from .whiten import whiten

        import ml_ms4alg

        print('MountainSort4......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        # NUM_WORKERS env var overrides the worker count when set.
        num_workers = os.environ.get('NUM_WORKERS', None)
        if num_workers:
            num_workers = int(num_workers)

        # Bandpass filter
        if self.freq_min or self.freq_max:
            recording = bandpass_filter(
                recording=recording, freq_min=self.freq_min, freq_max=self.freq_max)

        # Whiten
        if self.whiten:
            recording = whiten(recording=recording)

        # Sort
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers
        )

        # Curate
        # if self.noise_overlap_threshold is not None:
        #    sorting=ml_ms4alg.mountainsort4_curation(
        #      recording=recording,
        #      sorting=sorting,
        #      noise_overlap_threshold=self.noise_overlap_threshold
        #    )        

        SFMdaSortingExtractor.write_sorting(
            sorting=sorting, save_path=self.firings_out)
Beispiel #17
0
 def run(self):
     """Compute the average waveform of one unit and write it as JSON.

     Extracts random spike waveforms for ``self.unit_id``, takes the
     per-sample median across spikes (axis=2), and writes the channel ids
     plus the averaged waveform to ``self.json_out``.
     """
     recording = SFMdaRecordingExtractor(
         dataset_directory=self.recording_directory, download=True)
     sorting = SFMdaSortingExtractor(firings_file=self.firings)
     waveforms0 = _get_random_spike_waveforms(recording=recording,
                                              sorting=sorting,
                                              unit=self.unit_id)
     channel_ids = recording.get_channel_ids()
     # Median across spikes is used as the "average" waveform.
     avg_waveform = np.median(waveforms0, axis=2)
     ret = dict(channel_ids=channel_ids,
                average_waveform=avg_waveform.tolist())
     with open(self.json_out, 'w') as f:
         json.dump(ret, f)
    def __init__(self, sorter, recording, gt_sorting, params_to_opt,
                 space=None, run_schedule=None, metric='accuracy',
                 recdir=None, outfile=None, x0=None, y0=None):
        """Set up a sorter-parameter optimizer against a ground truth.

        Args:
            sorter: sorter name (case-insensitive; looked up in ss.sorter_dict).
            recording: recording extractor to sort.
            gt_sorting: ground-truth sorting extractor.
            params_to_opt: mapping of parameter names to search ranges.
            space: optional search space definition.
            run_schedule: iterations per optimization stage; defaults to
                [50, 50]. (Fix: was a mutable default argument, shared
                across instances; now a None sentinel.)
            metric: 'accuracy', 'precision', 'recall', 'f1' or 'spikeforest'.
            recdir: recording directory (needed for the 'spikeforest' metric).
            outfile: optional path for results output.
            x0, y0: optional warm-start points for the optimizer.
        """
        self.sorter = sorter.lower()
        self.re = recording
        self.gt_se = gt_sorting
        self.params_to_opt = OrderedDict(params_to_opt)
        self.outfile = outfile
        self.run_schedule = [50, 50] if run_schedule is None else run_schedule
        self.space = space
        self.best_parameters = None
        self.iteration = 0
        self.metric = metric.lower()
        self.recdir = recdir
        self.results_obj = None
        self.SorterClass = ss.sorter_dict[self.sorter]
        self.true_units_above = None
        self.x0 = x0
        self.y0 = y0

        if self.metric == 'spikeforest':
            # The SpikeForest comparison needs the ground truth on disk.
            tmp_dir = 'test_outputs_spikeforest'
            if not os.path.exists(tmp_dir):
                print('Creating folder {} for temporary data - note this is not cleaned up.'.format(tmp_dir))
                os.makedirs(tmp_dir)
            SFMdaSortingExtractor.write_sorting(sorting=self.gt_se,
                                                save_path=os.path.join(tmp_dir, 'firings_true.mda'))
            print('Compute units info...')
            sa.ComputeUnitsInfo.execute(recording_dir=self.recdir,
                                        firings=os.path.join(tmp_dir, 'firings_true.mda'),
                                        json_out=os.path.join(tmp_dir, 'true_units_info.json'))

            true_units_info = mt.loadObject(path=os.path.join(tmp_dir, 'true_units_info.json'))
            # Only score against units with sufficiently high SNR.
            # (Removed an unused local `true_units_info_by_unit_id`.)
            snrthresh = 8
            self.true_units_above = [u['unit_id'] for u in true_units_info if u['snr'] > snrthresh]
            print('Only testing ground truth units with snr > 8: ', self.true_units_above)
Beispiel #19
0
def yass_example(download=True, set_id=1):
    """Fetch one of the six VISAPy MEA ground-truth example datasets.

    Args:
        download: whether to download the recording data.
        set_id: which example set to load; must be between 1 and 6.

    Returns:
        Tuple ``(recording_extractor, sorting_extractor)`` for the set.

    Raises:
        Exception: if ``set_id`` is outside 1..6.
    """
    if set_id in range(1, 7):
        dsdir = 'kbucket://15734439d8cf/groundtruth/visapy_mea/set{}'.format(
            set_id)
        IX = SFMdaRecordingExtractor(dataset_directory=dsdir,
                                     download=download)
        path1 = os.path.join(dsdir, 'firings_true.mda')
        print(path1)
        OX = SFMdaSortingExtractor(path1)
        return (IX, OX)
    else:
        # Fix: error message typo 'betewen' -> 'between'.
        raise Exception(
            'Invalid ID for yass_example {} is not between 1..6'.format(
                set_id))
    def run(self):
        """Run SpyKING CIRCUS via spikesorters and write sorted firings."""

        import spikesorters as sorters
        print('SpyKING CIRCUS......')
        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/spyking-circus-' + code

        # Worker count from the environment (defaults to 1).
        num_workers = int(os.environ.get('NUM_WORKERS', '1'))

        sorter = sorters.SpykingcircusSorter(recording=recording,
                                             output_folder=tmpdir,
                                             verbose=True,
                                             delete_output_folder=True)

        sorter.set_params(
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            detect_threshold=self.detect_threshold,
            template_width_ms=self.template_width_ms,
            filter=self.filter,
            merge_spikes=True,
            auto_merge=0.5,
            num_workers=num_workers,
            electrode_dimensions=None,
            whitening_max_elts=self.whitening_max_elts,
            clustering_max_elts=self.clustering_max_elts,
        )

        timer = sorter.run()
        # Emit the elapsed-time marker line.
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Beispiel #21
0
    def run(self):
        """Run Kilosort2 via spikesorters and write the sorted firings.

        Raises:
            Exception: if installing KiloSort2 fails.
        """

        import spikesorters as sorters
        print('Kilosort2......')

        try:
            kilosort2_path = KiloSort2.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing kilosort.')
        sorters.Kilosort2Sorter.set_kilosort2_path(kilosort2_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort2-' + code

        sorter = sorters.Kilosort2Sorter(recording=recording,
                                         output_folder=tmpdir,
                                         debug=True,
                                         delete_output_folder=True)

        sorter.set_params(detect_threshold=self.detect_threshold,
                          car=self.car,
                          minFR=self.minFR,
                          electrode_dimensions=None,
                          freq_min=self.freq_min,
                          sigmaMask=self.adjacency_radius,
                          nPCs=self.pc_per_chan)

        timer = sorter.run()
        # Emit the elapsed-time marker line.
        print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
Beispiel #22
0
    def run(self):
        """Run KiloSort via spikesorters and write the sorted firings.

        Raises:
            Exception: if installing KiloSort fails.
        """

        import spikesorters as sorters
        print('KiloSort......')

        try:
            kilosort_path = KiloSort.install()
        except:
            traceback.print_exc()
            raise Exception('Problem installing kilosort.')
        sorters.KilosortSorter.set_kilosort_path(kilosort_path)

        recording = SFMdaRecordingExtractor(self.recording_dir)
        # Random suffix so concurrent runs get distinct temp dirs.
        code = ''.join(
            random.choice(string.ascii_uppercase) for x in range(10))
        tmpdir = os.environ.get('TEMPDIR', '/tmp') + '/kilosort-' + code

        sorter = sorters.KilosortSorter(recording=recording,
                                        output_folder=tmpdir,
                                        debug=True,
                                        delete_output_folder=True)

        sorter.set_params(detect_threshold=self.detect_threshold,
                          freq_min=self.freq_min,
                          freq_max=self.freq_max,
                          car=True,
                          useGPU=True,
                          electrode_dimensions=None)

        # TODO: get elapsed time from the return of this run
        sorter.run()

        sorting = sorter.get_result()

        SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                            save_path=self.firings_out)
    def initialize(self):
        """Lazily build the sorting extractor for this context (runs once)."""
        if self._initialized:
            return
        self._initialized = True

        # self._recording_context.initialize()

        print('******** FORESTVIEW: Initializing sorting result context')

        firings = self._sorting_result_object['firings']
        self._sorting_extractor = (
            SFMdaSortingExtractor(firings_file=firings) if firings else None)

        print('******** FORESTVIEW: Done initializing sorting result context')
import os
import shutil
from spikeforest import example_datasets
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor

# Generate a small synthetic ("toy") recording together with its
# ground-truth sorting, then persist both to disk in MDA format.
recording, sorting_true = example_datasets.toy_example1()

recdir = 'toy_example1'

# remove the toy recording directory if it exists, so the write starts clean
if os.path.exists(recdir):
    shutil.rmtree(recdir)

print('Preparing toy recording...')
SFMdaRecordingExtractor.write_recording(recording=recording, save_path=recdir)
# Store the ground-truth firings next to the recording at the conventional
# path so downstream comparison scripts can find them.
SFMdaSortingExtractor.write_sorting(sorting=sorting_true,
                                    save_path=recdir + '/firings_true.mda')
Beispiel #25
0
    def run(self):
        """Run the Tridesclous (tdc) spike sorter on ``self.recording_dir``.

        Builds the tdc parameter dictionary from this object's configuration
        attributes, stages the probe file and raw signal data in a temporary
        directory, initializes a ``tdc.DataIO`` on them, runs the sorter via
        ``tdc_helper``, and writes the resulting sorting to
        ``self.firings_out`` in MDA format. The temporary directory is
        removed unless ``self._keep_temp_files`` is set.
        """
        import tridesclous as tdc

        tmpdir = Path(_get_tmpdir('tdc'))
        recording = SFMdaRecordingExtractor(self.recording_dir)

        # Tridesclous parameter tree: values not taken from ``self`` are
        # fixed defaults. NOTE(review): key spellings (e.g. 'kde_bandwith')
        # presumably follow the tridesclous API exactly — do not "fix" them.
        params = {
            'fullchain_kargs': {
                'duration': 300.,
                'preprocessor': {
                    'highpass_freq': self.freq_min,
                    'lowpass_freq': self.freq_max,
                    'smooth_size': 0,
                    'chunksize': 1024,
                    'lostfront_chunksize': 128,
                    'signalpreprocessor_engine': 'numpy',
                    'common_ref_removal': self.common_ref_removal,
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': self.detection_threshold,
                    'peak_span': self.peak_span,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': self.waveforms_n_left,
                    'n_right': self.waveforms_n_right,
                    'mode': 'rand',
                    'nb_max': 20000,
                    'align_waveform': self.align_waveform,
                },
                'clean_waveforms': {
                    'alien_value_threshold': self.alien_value_threshold,
                },
            },
            'feat_method': 'peak_max',
            'feat_kargs': {},
            'clust_method': 'sawchaincut',
            'clust_kargs': {
                'kde_bandwith': 1.
            },
        }

        # save prb file:
        probe_file = tmpdir / 'probe.prb'
        se.save_probe_file(recording, probe_file, format='spyking_circus')

        # source file: reuse the existing binary file when the recording is
        # already frame-first raw data on disk; otherwise dump a new one.
        if isinstance(recording,
                      se.BinDatRecordingExtractor) and recording._frame_first:
            # no need to copy
            raw_filename = recording._datfile
            dtype = recording._timeseries.dtype.str
            nb_chan = len(recording._channels)
            offset = recording._timeseries.offset
        else:
            # save binary file (chunk by chunk) into a new file
            raw_filename = tmpdir / 'raw_signals.raw'
            n_chan = recording.get_num_channels()
            # keep each chunk around 2**24 samples total across channels
            chunksize = 2**24 // n_chan
            se.write_binary_dat_format(recording,
                                       raw_filename,
                                       time_axis=0,
                                       dtype='float32',
                                       chunksize=chunksize)
            dtype = 'float32'
            offset = 0

        # initialize source and probe file
        tdc_dataio = tdc.DataIO(dirname=str(tmpdir))
        nb_chan = recording.get_num_channels()

        tdc_dataio.set_data_source(
            type='RawData',
            filenames=[str(raw_filename)],
            dtype=dtype,
            sample_rate=recording.get_sampling_frequency(),
            total_channel=nb_chan,
            offset=offset)
        tdc_dataio.set_probe_file(str(probe_file))

        try:
            sorting = tdc_helper(tmpdir=tmpdir,
                                 params=params,
                                 recording=recording)
            SFMdaSortingExtractor.write_sorting(sorting=sorting,
                                                save_path=self.firings_out)
        except:
            # On failure, clean up the temporary directory (unless the user
            # asked to keep it) and re-raise the original error.
            if os.path.exists(tmpdir):
                if not getattr(self, '_keep_temp_files', False):
                    shutil.rmtree(tmpdir)
            raise
        if not getattr(self, '_keep_temp_files', False):
            shutil.rmtree(tmpdir)
Beispiel #26
0
 def sorting(self):
     """Return a sorting extractor backed by this object's firings file."""
     firings_path = self._obj['firings']
     return SFMdaSortingExtractor(firings_file=firings_path)
Beispiel #27
0
#!/usr/bin/env python

# Example: download a SpikeForest recording plus ground truth, then prepare
# a clean output directory for running a spike sorter on it.

from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth
# You can also substitute any of the other available recordings
# (the sha1dir:// URI is a content address resolved by mountaintools)
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')

# import a spike sorter from the spikesorters module of spikeforest
from spikeforestsorters import MountainSort4
import os
import shutil

# In place of MountainSort4 you could use any of the following:
#
# MountainSort4, SpykingCircus, KiloSort, KiloSort2, YASS
# IronClust, HerdingSpikes2, JRClust, Tridesclous, Klusta
# although the Matlab sorters require further setup.

# clear and create an empty output directory (keep things tidy)
if os.path.exists('test_outputs'):
    shutil.rmtree('test_outputs')
os.makedirs('test_outputs', exist_ok=True)
Beispiel #28
0
 def sortingTrue(self):
     """Return a sorting extractor for the ground-truth firings file."""
     truth_path = self.directory() + '/firings_true.mda'
     return SFMdaSortingExtractor(firings_file=truth_path)
Beispiel #29
0
# Example: load a drifting hybrid recording, its ground truth, and a
# previously computed MountainSort4 result, then compare sorting vs. truth.

from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load the recording with its ground truth
recdir = 'sha1dir://be6ce9f60fe1963af235862dc8197c9753b4b6f5.hybrid_janelia/drift_siprobe/rec_16c_1200s_11'

print('Loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')

# A previously computed MountainSort4 sorting, addressed by content hash
sorting_ms4 = SFMdaSortingExtractor(
    firings_file=
    'sha1://f1c6fdf52a2873d6f746e44dab6bf7ccd2937d97/f1c6fdf52a2873d6f746e44dab6bf7ccd2937d97/firings.mda'
)

# import from the spikeforest package
import spikeforest_analysis as sa

# write the ground truth firings file
SFMdaSortingExtractor.write_sorting(sorting=sorting_true,
                                    save_path='test_outputs/firings_true.mda')

# run the comparison, timing how long it takes
print('Compare with truth...')
import time
timer = time.time()

## Old method
Beispiel #30
0
# Example (beginning): download a SpikeForest recording and its ground truth.
# (The script continues beyond this excerpt.)

from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
from mountaintools import client as mt

# Configure to download from the public spikeforest kachery node
mt.configDownloadFrom('spikeforest.public')

# Load an example tetrode recording with its ground truth
# You can also substitute any of the other available recordings
recdir = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'

print('loading recording...')
recording = SFMdaRecordingExtractor(dataset_directory=recdir, download=True)
sorting_true = SFMdaSortingExtractor(firings_file=recdir + '/firings_true.mda')