Example #1
    def javascript_state_changed(self, prev_state, state):
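        # Lazily construct the recording extractor from the front-end state and
        # publish basic recording metadata back to the JavaScript side.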
        self._set_status('running', 'Running Recording')
        if not self._recording:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')            
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                channel_locations = None
            self.set_state(dict(
                num_channels=self._recording.get_num_channels(),
                channel_ids=self._recording.get_channel_ids(),
                channel_locations=channel_locations,
                num_timepoints=self._recording.get_num_frames(),
                samplerate=self._recording.get_sampling_frequency(),
                status_message='Loaded recording.'
            ))
        self._set_status('finished', '')
Example #2
    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running TimeseriesView')
        self._create_efficient_access = state.get('create_efficient_access',
                                                  False)
        if not self._recording:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
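            # Estimate per-channel baseline offsets and a global display scale
            # factor from the first ~25,000 frames of the recording.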
            traces0 = self._recording.get_traces(
                channel_ids=self._recording.get_channel_ids(),
                start_frame=0,
                end_frame=min(self._recording.get_num_frames(), 25000))
            y_offsets = -np.mean(traces0, axis=1)
            for m in range(traces0.shape[0]):
                traces0[m, :] = traces0[m, :] + y_offsets[m]
            vv = np.percentile(np.abs(traces0), 90)
            y_scale_factor = 1 / (2 * vv) if vv > 0 else 1
            self._segment_size = int(
                np.ceil(self._segment_size_times_num_channels /
                        self._recording.get_num_channels()))
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                channel_locations = None
            self.set_state(
                dict(num_channels=self._recording.get_num_channels(),
                     channel_ids=self._recording.get_channel_ids(),
                     channel_locations=channel_locations,
                     num_timepoints=self._recording.get_num_frames(),
                     y_offsets=y_offsets,
                     y_scale_factor=y_scale_factor,
                     samplerate=self._recording.get_sampling_frequency(),
                     segment_size=self._segment_size,
                     status_message='Loaded recording.'))

        # SR = state.get('segmentsRequested', {})
        # for key in SR.keys():
        #     aa = SR[key]
        #     if not self.get_python_state(key, None):
        #         self.set_state(dict(status_message='Loading segment {}'.format(key)))
        #         data0 = self._load_data(aa['ds'], aa['ss'])
        #         data0_base64 = _mda32_to_base64(data0)
        #         state0 = {}
        #         state0[key] = dict(data=data0_base64, ds=aa['ds'], ss=aa['ss'])
        #         self.set_state(state0)
        #         self.set_state(dict(status_message='Loaded segment {}'.format(key)))
        self._set_status('finished', '')
Example #3
def herdingspikes2(recording_path,
                   sorting_out,
                   filter=True,
                   pre_scale=True,
                   pre_scale_value=20):
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    recording = AutoRecordingExtractor(dict(path=recording_path),
                                       download=True)

    # Sorting
    print('Sorting...')

    output_folder = '/tmp/tmp_herdingspikes2_' + _random_string(8)
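    # Use a unique temporary output folder so concurrent runs do not collide.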
    # important for when we are in a container
    os.environ['HS2_PROBE_PATH'] = output_folder
    sorter = ss.HerdingspikesSorter(recording=recording,
                                    output_folder=output_folder,
                                    delete_output_folder=True)

    num_workers = os.environ.get('NUM_WORKERS', None)
    if not num_workers:
        num_workers = '1'
    num_workers = int(num_workers)

    sorter.set_params(filter=filter,
                      pre_scale=pre_scale,
                      pre_scale_value=pre_scale_value,
                      clustering_n_jobs=num_workers)
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #4
def waveclus(
    recording_path,
    sorting_out
):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._waveclussorter import WaveclusSorter

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)
    
    # Sorting
    print('Sorting...')
    sorter = WaveclusSorter(
        recording=recording,
        output_folder='/tmp/tmp_waveclus_' + _random_string(8),
        delete_output_folder=True
    )

    sorter.set_params(
    )

    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #5
def kilosort2(recording, sorting_out):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosort2sorter import Kilosort2Sorter
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    recording = AutoRecordingExtractor(dict(path=recording), download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)
    
    # Sorting
    print('Sorting...')
    sorter = Kilosort2Sorter(
        recording=recording,
        output_folder='/tmp/tmp_kilosort2_' + _random_string(8),
        delete_output_folder=True
    )

    sorter.set_params(
        detect_sign=-1,
        detect_threshold=5,
        freq_min=150,
        pc_per_chan=3
    )     
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #6
def tridesclous(
    recording_path,
    sorting_out
):
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)
    
    # Sorting
    print('Sorting...')

    output_folder = '/tmp/tmp_tridesclous_' + _random_string(8)
    os.environ['HS2_PROBE_PATH'] = output_folder  # important when running in a container (likely carried over from the herdingspikes2 wrapper)
    sorter = ss.TridesclousSorter(
        recording=recording,
        output_folder=output_folder,
        delete_output_folder=True,
        verbose=True,
    )

    # num_workers = os.environ.get('NUM_WORKERS', None)
    # if not num_workers: num_workers='1'
    # num_workers = int(num_workers)

    sorter.set_params(
    )
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #7
def mountainsort4(recording: str, sorting_out: str) -> None:
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    recording = AutoRecordingExtractor(dict(path=recording), download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)

    # Preprocessing
    print('Preprocessing...')
    recording = st.preprocessing.bandpass_filter(recording,
                                                 freq_min=300,
                                                 freq_max=6000)
    recording = st.preprocessing.whiten(recording)
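    # Bandpass filtering isolates the spike band; whitening decorrelates the
    # channels before detection.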

    # Sorting
    print('Sorting...')
    sorter = ss.Mountainsort4Sorter(recording=recording,
                                    output_folder='/tmp/tmp_mountainsort4_' +
                                    _random_string(8),
                                    delete_output_folder=True)

    sorter.set_params(detect_sign=-1, adjacency_radius=50, detect_threshold=4)
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #8
def kilosort(
        recording_path,
        sorting_out,
        detect_threshold=6,
        freq_min=300,
        freq_max=6000,
        Nt=128 * 1024 * 5 + 64  # batch size for kilosort
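        # note: Nt is accepted here but not forwarded to set_params below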
):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosortsorter import KilosortSorter

    recording = AutoRecordingExtractor(dict(path=recording_path),
                                       download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)

    # Sorting
    print('Sorting...')
    sorter = KilosortSorter(recording=recording,
                            output_folder='/tmp/tmp_kilosort_' +
                            _random_string(8),
                            delete_output_folder=True)

    sorter.set_params(detect_threshold=detect_threshold,
                      freq_min=freq_min,
                      freq_max=freq_max,
                      car=True)

    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #9
def mountainsort4(
    recording_path: str,
    sorting_out: str,
    detect_sign=-1,
    adjacency_radius=50,
    clip_size=50,
    detect_threshold=3,
    detect_interval=10,
    freq_min=300,
    freq_max=6000
):
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)

    # for quick testing
    # import spikeextractors as se
    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 1)
    
    # Preprocessing
    # print('Preprocessing...')
    # recording = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000)
    # recording = st.preprocessing.whiten(recording)

    # Sorting
    print('Sorting...')
    sorter = ss.Mountainsort4Sorter(
        recording=recording,
        output_folder='/tmp/tmp_mountainsort4_' + _random_string(8),
        delete_output_folder=True
    )

    num_workers = os.environ.get('NUM_WORKERS', None)
    if num_workers:
        num_workers = int(num_workers)
    else:
        num_workers = 0

    sorter.set_params(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        clip_size=clip_size,
        detect_threshold=detect_threshold,
        detect_interval=detect_interval,
        num_workers=num_workers,
        curation=False,
        whiten=True,
        filter=True,
        freq_min=freq_min,
        freq_max=freq_max
    )     
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #10
def filter_recording(recording_directory, timeseries_out):
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st
    rx = AutoRecordingExtractor(recording_directory)
    rx2 = st.preprocessing.bandpass_filter(recording=rx,
                                           freq_min=300,
                                           freq_max=6000,
                                           freq_wid=1000)
    if not writemda32(rx2.get_traces(), timeseries_out):
        raise Exception('Unable to write output file.')
Example #11
def jrclust(
    recording_path,
    sorting_out,
    detect_sign=-1,  # Use -1, 0, or 1, depending on the sign of the spikes in the recording
    adjacency_radius=50,
    detect_threshold=4.5, # detection threshold
    freq_min=300,
    freq_max=3000,
    merge_thresh=0.98,
    pc_per_chan=1,
    filter_type='bandpass', # {none, bandpass, wiener, fftdiff, ndiff}
    nDiffOrder='none',
    min_count=30,
    fGpu=0,
    fParfor=0,
    feature_type='gpca'  # {gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}
):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._jrclustsorter import JRClustSorter

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)
    
    # Sorting
    print('Sorting...')
    sorter = JRClustSorter(
        recording=recording,
        output_folder='/tmp/tmp_jrclust_' + _random_string(8),
        delete_output_folder=True
    )

    sorter.set_params(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        detect_threshold=detect_threshold,
        freq_min=freq_min,
        freq_max=freq_max,
        merge_thresh=merge_thresh,
        pc_per_chan=pc_per_chan,
        filter_type=filter_type,
        nDiffOrder=nDiffOrder,
        min_count=min_count,
        fGpu=fGpu,
        fParfor=fParfor,
        feature_type=feature_type
    )

    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #12
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st
    rx = AutoRecordingExtractor(recobj)
    rx2 = st.preprocessing.bandpass_filter(recording=rx, freq_min=freq_min, freq_max=freq_max, freq_wid=freq_wid)
    recobj2 = recobj.copy()
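    # Write the filtered traces to a temporary .mda file, register it with
    # kachery, and point the copied recording object at the new raw file.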
    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        if not writemda32(rx2.get_traces(), raw_fname):
            raise Exception('Unable to write output file.')
        recobj2['raw'] = ka.store_file(raw_fname)
        return recobj2
Example #13
def compute_recording_info(recording_path, json_out):
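    # Summarize sampling rate, channel count, and duration into a JSON file.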
    recording = AutoRecordingExtractor(recording_path)
    obj = dict(
        samplerate=recording.get_sampling_frequency(),
        num_channels=len(recording.get_channel_ids()),
        duration_sec=recording.get_num_frames() / recording.get_sampling_frequency()
    )
    with open(json_out, 'w') as f:
        json.dump(obj, f)
Example #14
def ironclust(recording, sorting_out):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._ironclustsorter import IronClustSorter
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    recording = AutoRecordingExtractor(dict(path=recording), download=True)

    # Sorting
    print('Sorting...')
    sorter = IronClustSorter(recording=recording,
                             output_folder='/tmp/tmp_ironclust_' +
                             _random_string(8),
                             delete_output_folder=True)

    sorter.set_params(detect_sign=-1,
                      adjacency_radius=50,
                      adjacency_radius_out=75,
                      detect_threshold=4,
                      prm_template_name='',
                      freq_min=300,
                      freq_max=8000,
                      merge_thresh=0.99,
                      pc_per_chan=0,
                      whiten=False,
                      filter_type='bandpass',
                      filter_detect_type='none',
                      common_ref_type='mean',
                      batch_sec_drift=300,
                      step_sec_drift=20,
                      knn=30,
                      min_count=30,
                      fGpu=True,
                      fft_thresh=8,
                      fft_thresh_low=0,
                      nSites_whiten=32,
                      feature_type='gpca',
                      delta_cut=1,
                      post_merge_mode=1,
                      sort_mode=1)
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #15
def spykingcircus(recording_path,
                  sorting_out,
                  detect_sign=-1,
                  adjacency_radius=200,
                  detect_threshold=6,
                  template_width_ms=3,
                  filter=True,
                  merge_spikes=True,
                  auto_merge=0.75,
                  whitening_max_elts=1000,
                  clustering_max_elts=10000):
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    recording = AutoRecordingExtractor(dict(path=recording_path),
                                       download=True)

    # Sorting
    print('Sorting...')
    sorter = ss.SpykingcircusSorter(recording=recording,
                                    output_folder='/tmp/tmp_spykingcircus_' +
                                    _random_string(8),
                                    delete_output_folder=True)

    num_workers = os.environ.get('NUM_WORKERS', None)
    if not num_workers:
        num_workers = '1'
    num_workers = int(num_workers)

    sorter.set_params(detect_sign=detect_sign,
                      adjacency_radius=adjacency_radius,
                      detect_threshold=detect_threshold,
                      template_width_ms=template_width_ms,
                      filter=filter,
                      merge_spikes=merge_spikes,
                      auto_merge=auto_merge,
                      num_workers=num_workers,
                      whitening_max_elts=whitening_max_elts,
                      clustering_max_elts=clustering_max_elts)
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #16
def klusta(
    recording_path,
    sorting_out,
    adjacency_radius=None,
    detect_sign=-1,
    threshold_strong_std_factor=5,
    threshold_weak_std_factor=2,
    n_features_per_channel=3,
    num_starting_clusters=3,
    extract_s_before=16,
    extract_s_after=32
):
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)
    
    # Sorting
    print('Sorting...')
    sorter = ss.KlustaSorter(
        recording=recording,
        output_folder='/tmp/tmp_klusta_' + _random_string(8),
        delete_output_folder=True
    )

    # num_workers = os.environ.get('NUM_WORKERS', None)
    # if not num_workers: num_workers='1'
    # num_workers = int(num_workers)

    sorter.set_params(
        adjacency_radius=adjacency_radius,
        detect_sign=detect_sign,
        threshold_strong_std_factor=threshold_strong_std_factor,
        threshold_weak_std_factor=threshold_weak_std_factor,
        n_features_per_channel=n_features_per_channel,
        num_starting_clusters=num_starting_clusters,
        extract_s_before=extract_s_before,
        extract_s_after=extract_s_after
    )     
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #17
def prepare_dataset_from_hash(
    recording_paths: Union[str, List[str]],
    gt_paths: Union[str, List[str]],
    sorter_names: List[str],
    metric_names: List[str],
    cache_path: Path,
):
    if isinstance(recording_paths, str):
        recording_paths = [recording_paths]

    if isinstance(gt_paths, str):
        gt_paths = [gt_paths]

    if len(recording_paths) != len(gt_paths):
        raise ValueError(
            f"You have provided {len(recording_paths)} recording hashes and {len(gt_paths)} ground truth hashes! These must be the same."
        )

    all_X = []
    all_y = []
    for i in range(len(recording_paths)):
        recording_path = recording_paths[i]
        gt_path = gt_paths[i]

        c_path = cache_path / recording_path.split('//')[1]
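        # The cache directory above is keyed by the hash portion of the
        # sha1dir:// recording path.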

        recording = AutoRecordingExtractor(recording_path, download=True)
        gt_sorting = AutoSortingExtractor(gt_path)

        session = SpikeSession(recording, gt_sorting, cache_path=c_path)

        X, y = prepare_fp_dataset(session,
                                  sorter_names=sorter_names,
                                  metric_names=metric_names)

        all_X.append(X)
        all_y.append(y)

    return np.vstack(all_X), np.hstack(all_y)
Example #18
def load_spikeforest_data(recording_path: str,
                          sorting_true_path: str,
                          download=True):
    recording = AutoRecordingExtractor(recording_path, download=download)
    sorting_GT = AutoSortingExtractor(sorting_true_path)
    # recording info
    fs = recording.get_sampling_frequency()
    channel_ids = recording.get_channel_ids()
    channel_loc = recording.get_channel_locations()
    num_frames = recording.get_num_frames()
    duration = recording.frame_to_time(num_frames)
    print(f'Sampling frequency: {fs}')
    print(f'Channel ids: {channel_ids}')
    print(f'Channel locations: {channel_loc}')
    print(f'Number of frames: {num_frames}')
    print(f'Recording duration: {duration}')
    # sorting_GT info
    unit_ids = sorting_GT.get_unit_ids()
    print(f'unit ids:{unit_ids}')
    return recording, sorting_GT
Example #19
def kilosort2(
        recording_path,
        sorting_out,
        detect_threshold=6,
        car=True,  # whether to do common average referencing
        minFR=1 / 50,  # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed
        freq_min=150,  # min. bp filter freq (Hz), use 0 for no filter
        sigmaMask=30,  # sigmaMask
        nPCs=3,  # PCs per channel?
):
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosort2sorter import Kilosort2Sorter

    recording = AutoRecordingExtractor(dict(path=recording_path),
                                       download=True)

    # recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=30000 * 10)

    # Sorting
    print('Sorting...')
    sorter = Kilosort2Sorter(recording=recording,
                             output_folder='/tmp/tmp_kilosort2_' +
                             _random_string(8),
                             delete_output_folder=True)

    sorter.set_params(detect_threshold=detect_threshold,
                      car=car,
                      minFR=minFR,
                      freq_min=freq_min,
                      sigmaMask=sigmaMask,
                      nPCs=nPCs)

    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #20
def spykingcircus(recording, sorting_out):
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    recording = AutoRecordingExtractor(dict(path=recording), download=True)

    # Sorting
    print('Sorting...')
    sorter = ss.SpykingcircusSorter(recording=recording,
                                    output_folder='/tmp/tmp_spykingcircus_' +
                                    _random_string(8),
                                    delete_output_folder=True)

    sorter.set_params()
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
Example #21
from pathlib import Path

import numpy as np

import kachery as ka
from pykilosort import Bunch, add_default_handler, run
from spikeextractors.extractors import bindatrecordingextractor as dat
from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

dat_path = Path("test/test.bin").absolute()
dir_path = dat_path.parent

ka.set_config(fr="default_readonly")
recording_path = "sha1dir://c0879a26f92e4c876cd608ca79192a84d4382868.manual_franklab/tetrode_600s/sorter1_1"
recording = AutoRecordingExtractor(recording_path, download=True)
recording.write_to_binary_dat_format(str(dat_path))
n_channels = len(recording.get_channel_ids())

probe = Bunch()
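# Minimal probe description for pykilosort: total channel count, an identity
# channel map, a single shank group, and x/y electrode coordinates.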
probe.NchanTOT = n_channels
probe.chanMap = np.array(range(0, n_channels))
probe.kcoords = np.ones(n_channels)
probe.xc = recording.get_channel_locations()[:, 0]
probe.yc = recording.get_channel_locations()[:, 1]

add_default_handler(level="DEBUG")

params = {"nfilt_factor": 8, "AUCsplit": 0.85, "nskip": 5}

run(
    dat_path,
    # the original snippet is truncated here; the remaining arguments below
    # are an assumption based on typical pykilosort usage, not from the source
    params=params,
    probe=probe,
    dir_path=dir_path,
)
Example #22
def main():
    import spikeextractors as se
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval
    
    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    num_events = 1000
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100
    k = 20
    ncomp = 4
    
    R = AutoRecordingExtractor(recobj2)

    X = R.get_traces()
    
    sig = X.copy()
    if detect_sign < 0:
        sig = -sig
    elif detect_sign == 0:
        sig = np.abs(sig)
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
    times_reference = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval_reference, detect_sign=1, margin=1000)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
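    # Taper each snippet with a Gaussian window so edge samples contribute less
    # to the PCA features.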
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    for j in range(snippets_reference.shape[0]):
        for m in range(snippets_reference.shape[1]):
            snippets_reference[j, m, :] = snippets_reference[j, m, :] * window0
    A_snippets_reference = snippets_reference.reshape(snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    u, s, vh = np.linalg.svd(A_snippets_reference)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
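        # Each pass: detect events on the residual, denoise their windowed
        # snippets via nearest-neighbor PLS regression in PCA feature space,
        # and subtract the denoised snippets from the residual.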
        print(f'Pass {passnum}')
        sig = X.copy()
        if detect_sign < 0:
            sig = -sig
        elif detect_sign == 0:
            sig = np.abs(sig)
        sig = np.max(sig, axis=0)
        noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
        times = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval, detect_sign=1, margin=1000)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        for j in range(snippets.shape[0]):
            for m in range(snippets.shape[1]):
                snippets[j, m, :] = snippets[j, m, :] * window0
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        features = A_snippets @ components_reference
        
        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        A_snippets_denoised = features2 @ components_reference.T
        
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        for j in range(len(times)):
            t0 = times[j]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] = X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] + snippet_denoised_0
            X[:, t0-snippet_len[0]:t0+snippet_len[1]] = X[:, t0-snippet_len[0]:t0+snippet_len[1]] - snippet_denoised_0

    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
    
    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()
Example #23
class TimeseriesView:
    def __init__(self):
        super().__init__()
        self._recording = None
        self._multiscale_recordings = None
        self._segment_size_times_num_channels = 1000000
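        # Per-segment sample budget; the actual segment length is this value
        # divided by the channel count (computed when the recording loads).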
        self._segment_size = None

    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running TimeseriesView')
        self._create_efficient_access = state.get('create_efficient_access',
                                                  False)
        if not self._recording:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
            traces0 = self._recording.get_traces(
                channel_ids=self._recording.get_channel_ids(),
                start_frame=0,
                end_frame=min(self._recording.get_num_frames(), 25000))
            y_offsets = -np.mean(traces0, axis=1)
            for m in range(traces0.shape[0]):
                traces0[m, :] = traces0[m, :] + y_offsets[m]
            vv = np.percentile(np.abs(traces0), 90)
            y_scale_factor = 1 / (2 * vv) if vv > 0 else 1
            self._segment_size = int(
                np.ceil(self._segment_size_times_num_channels /
                        self._recording.get_num_channels()))
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                channel_locations = None
            self.set_state(
                dict(num_channels=self._recording.get_num_channels(),
                     channel_ids=self._recording.get_channel_ids(),
                     channel_locations=channel_locations,
                     num_timepoints=self._recording.get_num_frames(),
                     y_offsets=y_offsets,
                     y_scale_factor=y_scale_factor,
                     samplerate=self._recording.get_sampling_frequency(),
                     segment_size=self._segment_size,
                     status_message='Loaded recording.'))

        # SR = state.get('segmentsRequested', {})
        # for key in SR.keys():
        #     aa = SR[key]
        #     if not self.get_python_state(key, None):
        #         self.set_state(dict(status_message='Loading segment {}'.format(key)))
        #         data0 = self._load_data(aa['ds'], aa['ss'])
        #         data0_base64 = _mda32_to_base64(data0)
        #         state0 = {}
        #         state0[key] = dict(data=data0_base64, ds=aa['ds'], ss=aa['ss'])
        #         self.set_state(state0)
        #         self.set_state(dict(status_message='Loaded segment {}'.format(key)))
        self._set_status('finished', '')

    def on_message(self, msg):

        if msg['command'] == 'requestSegment':
            ds = msg['ds_factor']
            ss = msg['segment_num']
            data0 = self._load_data(ds, ss)
            data0_base64 = _mda32_to_base64(data0)
            self.send_message(
                dict(command='setSegment',
                     ds_factor=ds,
                     segment_num=ss,
                     data=data0_base64))

    def _load_data(self, ds, ss):
        if not self._recording:
            return
        logger.info('_load_data {} {}'.format(ds, ss))
        if ds > 1:
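            # Downsampled requests are served from lazily created multiscale
            # recordings.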
            if self._multiscale_recordings is None:
                self.set_state(
                    dict(status_message='Creating multiscale recordings...'))
                self._multiscale_recordings = _create_multiscale_recordings(
                    recording=self._recording,
                    progressive_ds_factor=3,
                    create_efficient_access=self._create_efficient_access)
                self.set_state(
                    dict(status_message='Done creating multiscale recording'))
            rx = self._multiscale_recordings[ds]
            # print('_extract_data_segment', ds, ss, self._segment_size)
            start_time = time.time()
            X = _extract_data_segment(recording=rx,
                                      segment_num=ss,
                                      segment_size=self._segment_size * 2)
            # print('done extracting data segment', time.time() - start_time)
            logger.info('extracted data segment {} {} {}'.format(
                ds, ss,
                time.time() - start_time))
            return X

        start_time = time.time()
        traces = self._recording.get_traces(
            start_frame=ss * self._segment_size,
            end_frame=(ss + 1) * self._segment_size)
        logger.info('extracted data segment {} {} {}'.format(
            ds, ss,
            time.time() - start_time))
        return traces

    def iterate(self):
        pass

    def _set_state(self, **kwargs):
        self.set_state(kwargs)

    def _set_error(self, error_message):
        self._set_status('error', error_message)

    def _set_status(self, status, status_message=''):
        self._set_state(status=status, status_message=status_message)
Example #24
import argparse
import spikeforest_widgets as sw
from spikeforest2_utils import AutoRecordingExtractor
import kachery as ka

sw.init_electron()

ka.set_config(fr='default_readonly')

parser = argparse.ArgumentParser(description='Browse a SpikeForest analysis')
# parser.add_argument('--path', help='Path to the analysis JSON file', required=True)

args = parser.parse_args()

R = AutoRecordingExtractor('sha1dir://49b1fe491cbb4e0f90bde9cfc31b64f985870528.paired_boyden32c/531_2_1')

X = sw.TimeseriesView(recording=R)
X.show()
Example #25
def compute_units_info(recording_path, sorting_path, json_out):
    recording = AutoRecordingExtractor(recording_path)
    sorting = AutoSortingExtractor(sorting_path, samplerate=recording.get_sampling_frequency())
    obj = _compute_units_info(recording=recording, sorting=sorting)
    with open(json_out, 'w') as f:
        json.dump(obj, f)
Example #26
# Downloading the recording objects
ka.set_config(fr='default_readonly')
print(
    f'Downloading recording: {recordingZero["studyName"]}/{recordingZero["name"]}'
)
ka.load_file(recordingZero['directory'] + '/raw.mda')
ka.load_file(recordingZero['directory'] + '/params.json')
ka.load_file(recordingZero['directory'] + '/geom.csv')
ka.load_file(recordingZero['directory'] + '/firings_true.mda')

# Attaching the results
recordingZero['results'] = dict()

# Trying to plot the recordings
recordingInput = AutoRecordingExtractor(dict(path=recordingPath),
                                        download=True)
w_ts = sw.plot_timeseries(recordingInput)
w_ts.figure.suptitle("Recording by group")
w_ts.ax.set_ylabel("Channel_ids")

# We will also try to plot the raster plot for the ground truth
gtOutput = AutoSortingExtractor(sortingPath)
# We need to change the indices of the ground truth output
w_rs_gt = sw.plot_rasters(gtOutput, sampling_frequency=sampleRate)

# Spike sorting
# Trying to run SpykingCircus through SpikeInterface
with ka.config(fr='default_readonly'):
    #with hither.config(cache='default_readwrite'):
    with hither.config(container='default'):
Example #27
# Downloading the recording objects

ka.set_config(fr='default_readonly')
print(
    f'Downloading recording: {recordingZero["studyName"]}/{recordingZero["name"]}'
)
ka.load_file(recordingZero['directory'] + '/raw.mda')
ka.load_file(recordingZero['directory'] + '/params.json')
ka.load_file(recordingZero['directory'] + '/geom.csv')
ka.load_file(recordingZero['directory'] + '/firings_true.mda')

# Attaching the results
recordingZero['results'] = dict()

# Trying to plot the recordings
recordingInput = AutoRecordingExtractor(dict(path=recordingPath),
                                        download=True)
w_ts = sw.plot_timeseries(recordingInput)
w_ts.figure.suptitle("Recording by group")
w_ts.ax.set_ylabel("Channel_ids")

gtOutput = AutoSortingExtractor(sortingPath)
# Only getting a part of the recording
#gtOutput.add_epoch(epoch_name="first_half", start_frame=0, end_frame=recordingInput.get_num_frames()/2) #set

#subsorting = gtOutput.get_epoch("first_half")
#w_rs_gt = sw.plot_rasters(gtOutput,sampling_frequency=sampleRate)

#w_wf_gt = sw.plot_unit_waveforms(recordingInput,gtOutput, max_spikes_per_unit=100)

# We will also try to plot the raster plot for the ground truth
Example #28
def ironclust(recording_path,
              sorting_out,
              detect_threshold=4,
              freq_min=300,
              freq_max=0,
              detect_sign=-1,
              adjacency_radius=50,
              whiten=False,
              adjacency_radius_out=100,
              merge_thresh=0.95,
              fft_thresh=8,
              knn=30,
              min_count=30,
              delta_cut=1,
              pc_per_chan=6,
              batch_sec_drift=600,
              step_sec_drift=20,
              common_ref_type='trimmean',
              fGpu=True,
              clip_pre=0.25,
              clip_post=0.75,
              merge_thresh_cc=1):

    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._ironclustsorter import IronClustSorter

    recording = AutoRecordingExtractor(dict(path=recording_path),
                                       download=True)

    # Sorting
    print('Sorting...')
    sorter = IronClustSorter(recording=recording,
                             output_folder='/tmp/tmp_ironclust_' +
                             _random_string(8),
                             delete_output_folder=True)

    sorter.set_params(fft_thresh_low=0,
                      nSites_whiten=32,
                      feature_type='gpca',
                      post_merge_mode=1,
                      sort_mode=1,
                      prm_template_name='',
                      filter_type='bandpass',
                      filter_detect_type='none',
                      detect_threshold=detect_threshold,
                      freq_min=freq_min,
                      freq_max=freq_max,
                      detect_sign=detect_sign,
                      adjacency_radius=adjacency_radius,
                      whiten=whiten,
                      adjacency_radius_out=adjacency_radius_out,
                      merge_thresh=merge_thresh,
                      fft_thresh=fft_thresh,
                      knn=knn,
                      min_count=min_count,
                      delta_cut=delta_cut,
                      pc_per_chan=pc_per_chan,
                      batch_sec_drift=batch_sec_drift,
                      step_sec_drift=step_sec_drift,
                      common_ref_type=common_ref_type,
                      fGpu=fGpu,
                      clip_pre=clip_pre,
                      clip_post=clip_post,
                      merge_thresh_cc=merge_thresh_cc)
    timer = sorter.run()
    #print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)