Code Example #1
import numpy as np
import spikeextractors as se  # assumed binding for the 'se' used in these examples


def main():
    samplerate = 30000
    duration_sec = 10  # duration of the recording in seconds
    true_firing_rates_hz = [1, 2, 3, 4, 5]
    approx_false_negative_rates = [0, 0.1, 0.2, 0.3, 0.4]
    approx_false_positive_rates = [0, 0.2, 0.1, 0.4, 0.3]
    extra_unit_firing_rates_hz = [0.5, 1, 1.5]

    num_timepoints = samplerate * duration_sec

    sorting_true = se.NumpySortingExtractor()
    sorting = se.NumpySortingExtractor()
    for ii in range(len(true_firing_rates_hz)):
        num_events = int(duration_sec * true_firing_rates_hz[ii])
        times0 = np.random.choice(np.arange(num_timepoints),
                                  size=num_events,
                                  replace=False).astype(float)
        num_hits = int((1 - approx_false_negative_rates[ii]) * num_events)
        hits = np.random.choice(times0, size=num_hits, replace=False)
        num_extra = int(approx_false_positive_rates[ii] * num_events)
        extra = np.random.choice(np.arange(num_timepoints),
                                 size=num_extra,
                                 replace=False).astype(float)
        times1 = np.sort(np.concatenate([hits, extra]))

        sorting_true.add_unit(ii + 1, times0)
        sorting.add_unit(ii + 1, times1)

    for ii in range(len(extra_unit_firing_rates_hz)):
        num_events = int(duration_sec * extra_unit_firing_rates_hz[ii])
        times0 = np.random.choice(np.arange(num_timepoints),
                                  size=num_events,
                                  replace=False).astype(float)
        sorting.add_unit(len(true_firing_rates_hz) + ii + 1, times0)
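
The generator above only builds the two extractors; a minimal follow-up sketch for the end of main(), using just get_unit_ids()/get_unit_spike_train() from the API the rest of this page relies on:

    for unit_id in sorting_true.get_unit_ids():
        n_true = len(sorting_true.get_unit_spike_train(unit_id=unit_id))
        n_sorted = len(sorting.get_unit_spike_train(unit_id=unit_id))
        print('unit {}: {} true events, {} sorted events'.format(unit_id, n_true, n_sorted))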
Code Example #2
    def _create_example(self, seed):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        sampling_frequency = 30000
        X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames))
        geom = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
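        # note: despite the name, these are per-unit spike counts (array sizes), not spike times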
        train1 = np.sort(np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(sampling_frequency)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
        RX.set_channel_locations([0, 0], channel_ids=0)
        for i, unit_id in enumerate(SX2.get_unit_ids()):
            SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)
            SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
                                        value=np.asarray([i] * spike_times2[i]))
        for i, channel_id in enumerate(RX.get_channel_ids()):
            RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)

        SX3 = se.NumpySortingExtractor()
        train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47])
        SX3.add_unit(unit_id=0, times=train3)
        features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35])
        features4 = np.asarray([0, 10, 20, 30])
        feature4_idx = np.asarray([0, 2, 4, 6])
        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)
        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4, indexes=feature4_idx)

        example_info = dict(
            channel_ids=channel_ids,
            num_channels=num_channels,
            num_frames=num_frames,
            sampling_frequency=sampling_frequency,
            unit_ids=[1, 2, 3],
            train1=train1,
            train2=train2,
            train3=train3,
            features3=features3,
            unit_prop=80,
            channel_prop=(0, 0)
        )

        return (RX, RX2, RX3, SX, SX2, SX3, example_info)
Code Example #3
def make_sorting(times1, labels1, times2, labels2):
    gt_sorting = se.NumpySortingExtractor()
    tested_sorting = se.NumpySortingExtractor()
    gt_sorting.set_times_labels(np.array(times1), np.array(labels1))
    tested_sorting.set_times_labels(np.array(times2), np.array(labels2))
    gt_sorting.set_sampling_frequency(30000)
    tested_sorting.set_sampling_frequency(30000)
    return gt_sorting, tested_sorting
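
A hedged usage sketch for this helper; the spike times and unit labels below are made-up values:

gt, tested = make_sorting(times1=[100, 200, 300], labels1=[1, 1, 2],
                          times2=[101, 201, 350], labels2=[1, 1, 2])
print(gt.get_unit_ids())                   # [1, 2]
print(gt.get_unit_spike_train(unit_id=1))  # [100 200]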
Code Example #4
def make_sorting(times1, labels1, times2, labels2, times3, labels3):
    sorting1 = se.NumpySortingExtractor()
    sorting2 = se.NumpySortingExtractor()
    sorting3 = se.NumpySortingExtractor()
    sorting1.set_times_labels(np.array(times1), np.array(labels1))
    sorting2.set_times_labels(np.array(times2), np.array(labels2))
    sorting3.set_times_labels(np.array(times3), np.array(labels3))
    return sorting1, sorting2, sorting3
Code Example #5
    def __init__(self, arg, samplerate=None):
        super().__init__()
        if (isinstance(arg, dict)) and ('sorting_format' in arg):
            obj = dict(arg)
        else:
            obj = _create_object_for_arg(arg, samplerate=samplerate)
            assert obj is not None, f'Unable to create sorting from arg: {arg}'
        self._object: dict = obj

        if 'firings' in self._object:
            sorting_format = 'mda'
            data = {'firings': self._object['firings'], 'samplerate': self._object.get('samplerate', 30000)}
        else:
            sorting_format = self._object['sorting_format']
            data: dict = self._object['data']
        if sorting_format == 'mda':
            firings_path = kp.load_file(data['firings'])
            assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
            self._sorting: se.SortingExtractor = MdaSortingExtractor(firings_file=firings_path, samplerate=data['samplerate'])
        elif sorting_format == 'h5_v1':
            h5_path = kp.load_file(data['h5_path'])
            self._sorting = H5SortingExtractorV1(h5_path=h5_path)
        elif sorting_format == 'npy1':
            times_npy = kp.load_npy(data['times_npy_uri'])
            labels_npy = kp.load_npy(data['labels_npy_uri'])
            samplerate = data['samplerate']
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'snippets1':
            S = Snippets1SortingExtractor(snippets_h5_uri=data['snippets_h5_uri'], p2p=True)
            self._sorting = S
        elif sorting_format == 'npy2':
            npz = kp.load_npy(data['npz_uri'])
            times_npy = npz['spike_indexes']
            labels_npy = npz['spike_labels']
            samplerate = float(npz['sampling_frequency'])
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'nwb':
            from .nwbextractors import NwbSortingExtractor
            path0 = kp.load_file(data['path'])
            self._sorting: se.SortingExtractor = NwbSortingExtractor(path0)
        elif sorting_format == 'in_memory':
            S = get_in_memory_object(data)
            if S is None:
                raise Exception('Unable to find in-memory object for sorting')
            self._sorting = S
        else:
            raise Exception(f'Unexpected sorting format: {sorting_format}')

        self.copy_unit_properties(sorting=self._sorting)
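
For reference, a sketch of the kind of object the 'npy1' branch above expects; the URIs are placeholders, not real files:

sorting_object = {
    'sorting_format': 'npy1',
    'data': {
        'times_npy_uri': 'sha1://.../times.npy',    # 1-D array of spike times, in samples
        'labels_npy_uri': 'sha1://.../labels.npy',  # 1-D array of unit labels, same length
        'samplerate': 30000,
    },
}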
Code Example #6
    def _create_example(self):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        samplerate = 30000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X,
                                        samplerate=samplerate,
                                        geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X,
                                         samplerate=samplerate,
                                         geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X,
                                         samplerate=samplerate,
                                         geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
        train1 = np.sort(
            np.rint(np.random.uniform(0, num_frames,
                                      spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(samplerate)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.uniform(0, num_frames,
                                           spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4,
                     times=np.random.uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5,
                     times=np.random.uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        RX.set_channel_property(channel_id=0,
                                property_name='location',
                                value=(0, 0))
        example_info = dict(channel_ids=channel_ids,
                            num_channels=num_channels,
                            num_frames=num_frames,
                            samplerate=samplerate,
                            unit_ids=[1, 2, 3],
                            train1=train1,
                            unit_prop=80,
                            channel_prop=(0, 0))

        return (RX, RX2, RX3, SX, SX2, example_info)
Code Example #7
File: mountainsort4.py (project: alejoe91/ml_ms4alg)
def mountainsort4(*,recording,detect_sign,clip_size=50,adjacency_radius=-1,detect_threshold=3,detect_interval=10,num_workers=None):
  if num_workers is None:
    num_workers=int((multiprocessing.cpu_count()+1)/2)

  print('Using {} workers.'.format(num_workers))

  MS4=MountainSort4()
  MS4.setRecording(recording)
  geom=_get_geom_from_recording(recording)
  MS4.setGeom(geom)
  MS4.setSortingOpts(
    clip_size=clip_size,
    adjacency_radius=adjacency_radius,
    detect_sign=detect_sign,
    detect_interval=detect_interval,
    detect_threshold=detect_threshold
  )
  tmpdir = tempfile.mkdtemp()
  MS4.setNumWorkers(num_workers)
  print('Using tmpdir: '+tmpdir)
  MS4.setTemporaryDirectory(tmpdir)
  try:
    MS4.sort()
  except:
    print('Cleaning tmpdir: '+tmpdir)
    shutil.rmtree(tmpdir)
    raise
  print('Cleaning tmpdir: '+tmpdir)
  shutil.rmtree(tmpdir)
  times,labels,channels=MS4.eventTimesLabelsChannels()
  output=se.NumpySortingExtractor()
  output.set_times_labels(times=times,labels=labels)
  return output
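
A hedged end-to-end sketch: the toy_example generator from Code Example #9 below could feed this wrapper, assuming both functions are importable in one session:

RX, SX_true = toy_example(duration=10, num_channels=4, K=10)
sorting = mountainsort4(recording=RX, detect_sign=-1)
print('found units:', sorting.get_unit_ids())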
Code Example #8
 def setUp(self):
     M = 4
     N = 10000
     N_ttl = 50
     seed = 0
     sampling_frequency = 30000
     X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
     geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2))
     self._X = X
     self._geom = geom
     self._sampling_frequency = sampling_frequency
     self.RX = se.NumpyRecordingExtractor(
         timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
     self._ttl_frames = np.sort(np.random.permutation(N)[:N_ttl])
     self.RX.set_ttls(self._ttl_frames)
     self.SX = se.NumpySortingExtractor()
     L = 200
     self._train1 = np.rint(
         np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int)
     self.SX.add_unit(unit_id=1, times=self._train1)
     self.SX.add_unit(unit_id=2,
                      times=np.random.RandomState(seed=seed).uniform(
                          0, N, L))
     self.SX.add_unit(unit_id=3,
                      times=np.random.RandomState(seed=seed).uniform(
                          0, N, L))
Code Example #9
File: toy_example.py (project: yger/spikeextractors)
def toy_example(duration=10,
                num_channels=4,
                samplerate=30000.0,
                K=10,
                seed=None):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(K=K,
                                              duration=duration,
                                              samplerate=samplerate,
                                              seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              samplerate=samplerate,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac)
    SX.set_sampling_frequency(samplerate)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    samplerate=samplerate,
                                    geom=geom)
    return (RX, SX)
Code Example #10
def get_unmatched_sorting(sx1, sx2, ids1, ids2):
    ret = se.NumpySortingExtractor()
    for ii in range(len(ids1)):
        id1 = ids1[ii]
        id2 = ids2[ii]
        train1 = sx1.get_unit_spike_train(unit_id=id1)
        train2 = sx2.get_unit_spike_train(unit_id=id2)
        train = get_unmatched_times(train1, train2, delta=100)
        ret.add_unit(id1, train)
    return ret
Code Example #11
def get_unmatched_sorting(sx1, sx2, ids1, ids2):
    # spikes in first sorting that are not matched to spikes in second sorting
    ret = se.NumpySortingExtractor()
    for ii in range(len(ids1)):
        id1 = ids1[ii]
        id2 = ids2[ii]
        train1 = sx1.get_unit_spike_train(unit_id=id1)
        train2 = sx2.get_unit_spike_train(unit_id=id2)
        train = get_unmatched_times(train1, train2, delta=100)
        ret.add_unit(id1, train)
    return ret
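
get_unmatched_times itself is not shown on this page. A minimal sketch of one way to implement it, assuming "unmatched" means no event of the second train lies within +/- delta samples:

import numpy as np

def get_unmatched_times(times1, times2, *, delta):
    # keep events of times1 whose nearest neighbor in times2 is farther than delta
    times1 = np.asarray(times1)
    times2 = np.sort(np.asarray(times2))
    if times2.size == 0:
        return times1
    idx = np.searchsorted(times2, times1)
    left = np.abs(times1 - times2[np.maximum(idx - 1, 0)])
    right = np.abs(times1 - times2[np.minimum(idx, times2.size - 1)])
    return times1[np.minimum(left, right) > delta]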
Code Example #12
    def __init__(self, arg, samplerate=None):
        super().__init__()
        if (isinstance(arg, dict)) and ('sorting_format' in arg):
            obj = dict(arg)
        else:
            obj = _create_object_for_arg(arg, samplerate=samplerate)
            assert obj is not None, f'Unable to create sorting from arg: {arg}'
        self._object: dict = obj

        sorting_format = self._object['sorting_format']
        data: dict = self._object['data']
        if sorting_format == 'mda':
            firings_path = kp.load_file(data['firings'])
            assert firings_path is not None, f'Unable to load firings file: {data["firings"]}'
            self._sorting: se.SortingExtractor = MdaSortingExtractor(
                firings_file=firings_path, samplerate=data['samplerate'])
        elif sorting_format == 'h5_v1':
            h5_path = kp.load_file(data['h5_path'])
            self._sorting = H5SortingExtractorV1(h5_path=h5_path)
        elif sorting_format == 'npy1':
            times_npy = kp.load_npy(data['times_npy_uri'])
            labels_npy = kp.load_npy(data['labels_npy_uri'])
            samplerate = data['samplerate']
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        elif sorting_format == 'npy2':
            npz = kp.load_npy(data['npz_uri'])
            times_npy = npz['spike_indexes']
            labels_npy = npz['spike_labels']
            samplerate = float(npz['sampling_frequency'])
            S = se.NumpySortingExtractor()
            S.set_sampling_frequency(samplerate)
            S.set_times_labels(times_npy.ravel(), labels_npy.ravel())
            self._sorting = S
        else:
            raise Exception(f'Unexpected sorting format: {sorting_format}')

        self.copy_unit_properties(sorting=self._sorting)
Code Example #13
    def test_npz_extractor(self):
        path = self.test_dir + '/sorting.npz'
        se.NpzSortingExtractor.write_sorting(self.SX, path)
        SX_npz = se.NpzSortingExtractor(path)

        # empty write
        sorting_empty = se.NumpySortingExtractor()
        path_empty = self.test_dir + '/sorting_empty.npz'
        se.NpzSortingExtractor.write_sorting(sorting_empty, path_empty)

        check_sorting_return_types(SX_npz)
        check_sortings_equal(self.SX, SX_npz)
        check_dumping(SX_npz)
Code Example #14
File: ironclust.py (project: yger/spiketoolkit)
    def get_result_from_folder(output_folder):

        # overwrite the SorterBase.get_result
        from mountainlab_pytools import mdaio

        result_fname = Path(output_folder) / 'firings.mda'

        assert result_fname.exists(), 'Result file does not exist: {}'.format(
            str(result_fname))

        firings = mdaio.readmda(str(result_fname))
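        # firings.mda layout (0-indexed rows): 0 = primary channel, 1 = event time in samples, 2 = cluster label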
        sorting = se.NumpySortingExtractor()
        sorting.set_times_labels(firings[1, :], firings[2, :])
        return sorting
Code Example #15
def toy_example1(duration=10, num_channels=4, samplerate=30000, K=10, firing_rates=None, noise_level=10):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac)
    times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate, firing_rates=firing_rates)
    labels = labels.astype(np.int64)
    OX = se.NumpySortingExtractor()
    OX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=OX, waveforms=waveforms, noise_level=noise_level, samplerate=samplerate, duration=duration,
                              waveform_upsamplefac=upsamplefac)

    IX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    return (IX, OX)
Code Example #16
def setup_study():
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:

        # find raw file
        dirname = recording_folder + rec_name + '/'
        for f in os.listdir(dirname):
            if f.endswith('.raw') and not f.endswith('juxta.raw'):
                mea_filename = dirname + f

        # raw files have an internal offset that depends on the channel count;
        # a simple text header can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
            offset = int(re.findall(r'padding = (\d+)', f.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename,
                                          20000.,
                                          256,
                                          'uint16',
                                          offset=offset,
                                          frames_first=True)

        # this reduces the channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # gt sorting
        gt_indexes = np.fromfile(ground_truth_folder + rec_name +
                                 '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes,
                                    np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    study = GroundTruthStudy.setup(study_folder, gt_dict)
Code Example #17
 def setUp(self):
     M = 4
     N = 10000
     samplerate = 30000
     X = np.random.normal(0, 1, (M, N))
     geom = np.random.normal(0, 1, (M, 2))
     self._X = X
     self._geom = geom
     self._samplerate = samplerate
     self.RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
     self.SX = se.NumpySortingExtractor()
     L = 200
     self._train1 = np.rint(np.random.uniform(0, N, L)).astype(int)
     self.SX.add_unit(unit_id=1, times=self._train1)
     self.SX.add_unit(unit_id=2, times=np.random.uniform(0, N, L))
     self.SX.add_unit(unit_id=3, times=np.random.uniform(0, N, L))
Code Example #18
def toy_example(duration=10,
                num_channels=4,
                sampling_frequency=30000.0,
                K=10,
                dumpable=False,
                dump_folder=None,
                seed=None):
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
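
A hedged usage sketch of the two paths above; 'toy_dump' is an arbitrary folder name:

RX, SX = toy_example(duration=10, num_channels=4, seed=0)  # in-memory extractors
RX_d, SX_d = toy_example(duration=10, dumpable=True,
                         dump_folder='toy_dump', seed=0)   # Mda/Npz-backed extractors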
Code Example #19
def gen_synth_datasets(datasets,
                       *,
                       outdir,
                       num_channels=4,
                       upsamplefac=13,
                       samplerate=30000,
                       average_peak_amplitude=-100):
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        K = ds['K']
        duration = ds['duration']
        noise_level = ds['noise_level']
        waveforms, geom = synthesize_random_waveforms(
            K=K,
            M=num_channels,
            average_peak_amplitude=average_peak_amplitude,
            upsamplefac=upsamplefac)
        times, labels = synthesize_random_firings(K=K,
                                                  duration=duration,
                                                  samplerate=samplerate)
        labels = labels.astype(np.int64)
        OX = si.NumpySortingExtractor()
        OX.setTimesLabels(times, labels)
        X = synthesize_timeseries(output_extractor=OX,
                                  waveforms=waveforms,
                                  noise_level=noise_level,
                                  samplerate=samplerate,
                                  duration=duration,
                                  waveform_upsamplefac=upsamplefac)
        IX = si.NumpyRecordingExtractor(timeseries=X,
                                        samplerate=samplerate,
                                        geom=geom)
        si.MdaRecordingExtractor.writeRecording(IX,
                                                outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(
            OX, outdir + '/{}/firings_true.mda'.format(ds_name))
        # write json with two fields

    print('Done.')
Code Example #20
    def get_sorting_extractor(self, key, sort_interval):
        #TODO: replace with spikeinterface call if possible
        """Generates a numpy sorting extractor given a key that retrieves a SpikeSorting and a specified sort interval

        :param key: key for a single SpikeSorting
        :type key: dict
        :param sort_interval: [start_time, end_time]
        :type sort_interval: numpy array
        :return: a spikeextractors sorting extractor with the sorting information
        """
        # get the units object from the NWB file that the data are stored in.
        units = (SpikeSorting & key).fetch_nwb()[0]['units'].to_dataframe()
        unit_timestamps = []
        unit_labels = []
        
        # TODO: do something more efficient here; note that searching for matching sort_intervals within pandas doesn't seem to work
        for index, unit in units.iterrows():
            if np.ndarray.all(np.ravel(unit['sort_interval']) == sort_interval):
                unit_timestamps.extend(unit['spike_times'])
                unit_labels.extend([index]*len(unit['spike_times']))

        output = se.NumpySortingExtractor()
        output.set_times_labels(times=np.asarray(unit_timestamps), labels=np.asarray(unit_labels))
        return output
Code Example #21
def detect_spikes(recording,
                  channel_ids=None,
                  detect_threshold=5,
                  n_pad_ms=2,
                  upsample=1,
                  detect_sign=-1,
                  min_diff_samples=5,
                  parallel=False,
                  n_jobs=-1):
    '''
    Detects spikes per channel.
    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    channel_ids: list or None
        List of channels to perform detection. If None all channels are used
    detect_threshold: float
        Threshold in MAD to detect peaks
    n_pad_ms: float
        Time in ms to find absolute peak around detected peak
    upsample: int
        The detected waveforms are upsampled 'upsample' times (default=1)
    detect_sign: int
        Sign of the detection: -1 (negative), 1 (positive), 0 (both)
    min_diff_samples: int
        Minimum interval to skip consecutive spikes (default=5)
    parallel: bool
        If True, each channel is run in parallel
    n_jobs: int
        Number of jobs when parallel
    Returns
    -------
    sorting_detected: SortingExtractor
        The sorting extractor object with the detected spikes. Unit ids are the same as channel ids and units have the
        'channel' property to specify which channel they correspond to
    '''
    spike_times = []
    labels = []
    n_pad_samples = int(n_pad_ms * recording.get_sampling_frequency() / 1000)

    if channel_ids is None:
        channel_ids = recording.get_channel_ids()
    else:
        assert np.all([ch in recording.get_channel_ids() for ch in channel_ids]), "Not all 'channel_ids' are in the " \
                                                                                  "recording."

    if parallel:
        output = Parallel(n_jobs=n_jobs)(
            delayed(_detect_and_align_peaks_single_channel)(
                recording, ch, detect_threshold, detect_sign, n_pad_samples,
                upsample, min_diff_samples) for ch in channel_ids)
        for o in output:
            spike_times.append(o[0])
            labels.append(o[1])
    else:
        for ch in channel_ids:
            peak_times, label = _detect_and_align_peaks_single_channel(
                recording, ch, detect_threshold, detect_sign, n_pad_samples,
                upsample, min_diff_samples)
            spike_times.append(peak_times)
            labels.append(label)

    # create sorting extractor
    sorting = se.NumpySortingExtractor()
    labels_flat = np.array(list(itertools.chain(*labels)))
    times_flat = np.array(list(itertools.chain(*spike_times)))
    sorting.set_times_labels(times=times_flat, labels=labels_flat)

    for u in sorting.get_unit_ids():
        sorting.set_unit_property(u, 'channel', u)

    return sorting
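
A hedged usage sketch, assuming a RecordingExtractor named rec (for instance from toy_example elsewhere on this page):

sorting_detected = detect_spikes(rec, detect_threshold=5, detect_sign=-1)
for u in sorting_detected.get_unit_ids():
    print('unit', u, '-> channel', sorting_detected.get_unit_property(u, 'channel'))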
Code Example #22
File: waveclus.py (project: samuelgarcia/spikeforest)
def waveclus_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        **kwargs):

    waveclus_path = os.environ.get('WAVECLUS_PATH_DEV', None)
    if waveclus_path:
        print('Using waveclus from WAVECLUS_PATH_DEV directory: {}'.format(
            waveclus_path))
    else:
        try:
            print('Auto-installing waveclus.')
            waveclus_path = install_waveclus(
                repo='https://github.com/csn-le/wave_clus.git',
                commit='248d15c7eaa2b45b15e4488dfb9b09bfe39f5341')
        except:
            traceback.print_exc()
            raise Exception(
                'Problem installing waveclus. You can set the WAVECLUS_PATH_DEV to force to use a particular path.'
            )
    print('Using waveclus from: {}'.format(waveclus_path))

    dataset_dir = os.path.join(tmpdir, 'waveclus_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            params=params,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running waveclus in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath(genpath('{waveclus_path}'), '{source_path}', '{source_path}/mdaio');
        try
            p_waveclus('{tmpdir}', '{dataset_dir}/raw.mda', '{tmpdir}/firings.mda', {samplerate});
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(waveclus_path=waveclus_path,
                     tmpdir=tmpdir,
                     dataset_dir=dataset_dir,
                     source_path=source_path,
                     samplerate=samplerate)

    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_waveclus.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_waveclus
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_waveclus.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_waveclus.sh')
    time_ = time.time()
    shell_cmd.start()

    retcode = shell_cmd.wait()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - time_))

    if retcode != 0:
        raise Exception('waveclus returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
Code Example #23
def ironclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        ironclust_path,
        **kwargs):
    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    txt += 'samplerate={}\n'.format(samplerate)
    if 'scale_factor' in params:
        txt += 'scale_factor={}\n'.format(params["scale_factor"])
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # new method
    print('Running ironclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{source_dir}');
        addpath('{ironclust_path}', '{ironclust_path}/matlab', '{ironclust_path}/matlab/mdaio');
        try
            p_ironclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '', '', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(ironclust_path=ironclust_path, tmpdir=tmpdir, dataset_dir=dataset_dir, source_dir=source_dir)

    matlab_cmd = mlpr.ShellScript(cmd, script_path=tmpdir + '/run_ironclust.m', keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_ironclust
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd, script_path=tmpdir + '/run_ironclust.sh', keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_ironclust.sh')
    shell_cmd.start()

    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('ironclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
Code Example #24
def jrclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        **kwargs):

    jrclust_path = os.environ.get('JRCLUST_PATH_DEV', None)
    if jrclust_path:
        print('Using jrclust from JRCLUST_PATH_DEV directory: {}'.format(
            jrclust_path))
    else:
        try:
            print('Auto-installing jrclust.')
            jrclust_path = install_jrclust(
                repo='https://github.com/JaneliaSciComp/JRCLUST.git',
                commit='3d2e75c0041dca2a9f273598750c6a14dbc4c1b8')
        except:
            traceback.print_exc()
            raise Exception(
                'Problem installing jrclust. You can set the JRCLUST_PATH_DEV to force to use a particular path.'
            )
    print('Using jrclust from: {}'.format(jrclust_path))

    dataset_dir = os.path.join(tmpdir, 'jrclust_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            params=params,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    if 'scale_factor' in params:
        txt += 'bitScaling={}\n'.format(params["scale_factor"])
    txt += 'sampleRate={}\n'.format(samplerate)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running jrclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{jrclust_path}', '{source_path}', '{source_path}/mdaio');
        try
            p_jrclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(jrclust_path=jrclust_path,
                     tmpdir=tmpdir,
                     dataset_dir=dataset_dir,
                     source_path=source_path)

    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_jrclust.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_jrclust
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_jrclust.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_jrclust.sh')
    time_ = time.time()
    shell_cmd.start()

    retcode = shell_cmd.wait()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - time_))

    if retcode != 0:
        raise Exception('jrclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
Code Example #25
    def _create_example(self):
        channel_ids = [0, 1, 2, 3]
        num_channels = 4
        num_frames = 10000
        sampling_frequency = 30000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        RX = se.NumpyRecordingExtractor(timeseries=X,
                                        sampling_frequency=sampling_frequency,
                                        geom=geom)
        RX2 = se.NumpyRecordingExtractor(timeseries=X,
                                         sampling_frequency=sampling_frequency,
                                         geom=geom)
        RX3 = se.NumpyRecordingExtractor(timeseries=X,
                                         sampling_frequency=sampling_frequency,
                                         geom=geom)
        SX = se.NumpySortingExtractor()
        spike_times = [200, 300, 400]
        train1 = np.sort(
            np.rint(np.random.uniform(0, num_frames,
                                      spike_times[0])).astype(int))
        SX.add_unit(unit_id=1, times=train1)
        SX.add_unit(unit_id=2,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[1])))
        SX.add_unit(unit_id=3,
                    times=np.sort(
                        np.random.uniform(0, num_frames, spike_times[2])))
        SX.set_unit_property(unit_id=1, property_name='stability', value=80)
        SX.set_sampling_frequency(sampling_frequency)
        SX2 = se.NumpySortingExtractor()
        spike_times2 = [100, 150, 450]
        train2 = np.rint(np.random.uniform(0, num_frames,
                                           spike_times2[0])).astype(int)
        SX2.add_unit(unit_id=3, times=train2)
        SX2.add_unit(unit_id=4,
                     times=np.random.uniform(0, num_frames, spike_times2[1]))
        SX2.add_unit(unit_id=5,
                     times=np.random.uniform(0, num_frames, spike_times2[2]))
        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
        SX2.set_unit_spike_features(unit_id=3,
                                    feature_name='widths',
                                    value=np.asarray([3] * spike_times2[0]))
        RX.set_channel_property(channel_id=0,
                                property_name='location',
                                value=(0, 0))
        for i, unit_id in enumerate(SX2.get_unit_ids()):
            SX2.set_unit_property(unit_id=unit_id,
                                  property_name='shared_unit_prop',
                                  value=i)
            SX2.set_unit_spike_features(unit_id=unit_id,
                                        feature_name='shared_unit_feature',
                                        value=np.asarray([i] *
                                                         spike_times2[i]))
        for i, channel_id in enumerate(RX.get_channel_ids()):
            RX.set_channel_property(channel_id=channel_id,
                                    property_name='shared_channel_prop',
                                    value=i)
        example_info = dict(channel_ids=channel_ids,
                            num_channels=num_channels,
                            num_frames=num_frames,
                            sampling_frequency=sampling_frequency,
                            unit_ids=[1, 2, 3],
                            train1=train1,
                            unit_prop=80,
                            channel_prop=(0, 0))

        return (RX, RX2, RX3, SX, SX2, example_info)
Code Example #26
def ironclust(*,
    recording, # Recording object
    tmpdir, # Temporary working directory
    detect_sign=-1, # Polarity of the spikes, -1, 0, or 1
    adjacency_radius=-1, # Channel neighborhood adjacency radius corresponding to geom file
    detect_threshold=5, # Threshold for detection
    merge_thresh=.98, # Cluster merging threshold 0..1
    freq_min=300, # Lower frequency limit for band-pass filter
    freq_max=6000, # Upper frequency limit for band-pass filter
    pc_per_chan=3, # Number of pc per channel
    prm_template_name, # Name of the template file
    ironclust_src=None
):      
    if ironclust_src is None:
        ironclust_src=os.getenv('IRONCLUST_SRC',None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')
    source_dir=os.path.dirname(os.path.realpath(__file__))

    dataset_dir=tmpdir+'/ironclust_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording,save_path=dataset_dir)
        
    samplerate=recording.getSamplingFrequency()

    print('Reading timeseries header...')
    HH=mdaio.readmda_header(dataset_dir+'/raw.mda')
    num_channels=HH.dims[0]
    num_timepoints=HH.dims[1]
    duration_minutes=num_timepoints/samplerate/60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels,num_timepoints,duration_minutes))

    print('Creating .params file...')
    txt=''
    txt+='samplerate={}\n'.format(samplerate)
    txt+='detect_sign={}\n'.format(detect_sign)
    txt+='adjacency_radius={}\n'.format(adjacency_radius)
    txt+='detect_threshold={}\n'.format(detect_threshold)
    txt+='merge_thresh={}\n'.format(merge_thresh)
    txt+='freq_min={}\n'.format(freq_min)
    txt+='freq_max={}\n'.format(freq_max)    
    txt+='pc_per_chan={}\n'.format(pc_per_chan)
    txt+='prm_template_name={}\n'.format(prm_template_name)
    _write_text_file(dataset_dir+'/argfile.txt',txt)
        
    print('Running IronClust...')
    cmd_path = "addpath('{}', '{}/matlab', '{}/mdaio');".format(ironclust_src, ironclust_src, ironclust_src)
    #"p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');"\
        .format(tmpdir, dataset_dir+'/raw.mda', dataset_dir+'/geom.csv', tmpdir+'/firings.mda', dataset_dir+'/argfile.txt')
    cmd='matlab -nosplash -nodisplay -r "{} {} quit;"'.format(cmd_path, cmd_call)
    print(cmd)
    retcode=_run_command_and_print_output(cmd)

    if retcode != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    result_fname=tmpdir+'/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: '+ result_fname)
    
    firings=mdaio.readmda(result_fname)
    sorting=si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1,:],firings[2,:])
    return sorting
Code Example #27
File: utils.py (project: seankmartin/spiketoolkit)
def create_signal_with_known_waveforms(n_channels=4,
                                       n_waveforms=2,
                                       n_wf_samples=100,
                                       duration=5,
                                       fs=30000):
    '''
    Creates stereotyped recording, sorting, with waveforms, templates, and max_chans
    '''
    a_min = [-200, -50]
    a_max = [10, 50]
    wfs = []

    # gen waveforms
    for w in range(n_waveforms):
        amp_min = np.random.randint(a_min[0], a_min[1])
        amp_max = np.random.randint(a_max[0], a_max[1])

        wf = create_wf(amp_min, amp_max, n_wf_samples)
        wfs.append(wf)

    # gen templates
    templates = []
    max_chans = []
    for wf in wfs:
        found = False
        while not found:
            template, amps, found = generate_template_with_random_amps(
                n_channels, wf)
        templates.append(template)
        max_chans.append(np.argmax(amps))

    templates = np.array(templates)
    n_samples = int(fs * duration)

    # gen spiketrains
    interval = 10 * n_wf_samples
    times = np.arange(interval, duration * fs - interval, interval).astype(int)
    labels = np.zeros(len(times)).astype(int)
    for i, wf in enumerate(wfs):
        labels[i::len(wfs)] = i

    timeseries = np.zeros((n_channels, n_samples))
    waveforms = []
    amplitudes = []
    for i, tem in enumerate(templates):
        idxs = np.where(labels == i)
        wav = []
        amps = []
        for t in times[idxs]:
            rand_val = np.random.randn() * 0.01 + 1
            timeseries[:, t - n_wf_samples // 2:t +
                       n_wf_samples // 2] = rand_val * tem
            wav.append(rand_val * tem)
            amps.append(np.min(rand_val * tem))
        wav = np.array(wav)
        amps = np.array(amps)
        waveforms.append(wav)
        amplitudes.append(amps)

    rec = se.NumpyRecordingExtractor(timeseries=timeseries,
                                     sampling_frequency=fs)
    sort = se.NumpySortingExtractor()
    sort.set_times_labels(times=times, labels=labels)
    sort.set_sampling_frequency(fs)

    return rec, sort, waveforms, templates, max_chans, amplitudes
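
A hedged usage sketch; the returned lists line up index-wise, one waveform/template/max channel per unit:

rec, sort, waveforms, templates, max_chans, amplitudes = \
    create_signal_with_known_waveforms(n_channels=4, n_waveforms=2,
                                       n_wf_samples=100, duration=5, fs=30000)
print(sort.get_unit_ids())  # [0, 1] -- one unit per generated waveform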
Code Example #28
File: kilosort2.py (project: samuelgarcia/spikeforest)
def kilosort2_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
        adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
        detect_threshold=6,  # Threshold for detection
        merge_thresh=.98,  # Cluster merging threshold 0..1
        freq_min=150,  # Lower frequency limit for band-pass filter
        freq_max=6000,  # Upper frequency limit for band-pass filter
        pc_per_chan=3,  # number of PC per chan
        minFR=1 / 50):

    # # TODO: do not require ks2 to depend on irc -- rather, put all necessary .m code in the spikeforest repo
    # ironclust_path = os.environ.get('IRONCLUST_PATH_DEV', None)
    # if ironclust_path:
    #     print('Using ironclust from IRONCLUST_PATH_DEV directory: {}'.format(ironclust_path))
    # else:
    #     try:
    #         print('Auto-installing ironclust.')
    #         ironclust_path = install_ironclust(commit='042b600b014de13f6d11d3b4e50e849caafb4709')
    #     except:
    #         traceback.print_exc()
    #         raise Exception('Problem installing ironclust. You can set the IRONCLUST_PATH_DEV to force to use a particular path.')
    # print('For kilosort2, using ironclust utility functions from: {}'.format(ironclust_path))

    kilosort2_path = os.environ.get('KILOSORT2_PATH_DEV', None)
    if kilosort2_path:
        print('Using kilosort2 from KILOSORT2_PATH_DEV directory: {}'.format(
            kilosort2_path))
    else:
        try:
            print('Auto-installing kilosort2.')
            kilosort2_path = KiloSort2.install()
        except:
            traceback.print_exc()
            raise Exception(
                'Problem installing kilosort2. You can set the KILOSORT2_PATH_DEV to force to use a particular path.'
            )
    print('Using kilosort2 from: {}'.format(kilosort2_path))

    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/kilosort2_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'minFR={}\n'.format(minFR)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    print('Running Kilosort2 in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{source_dir}');
        addpath('{source_dir}/mdaio')
        try
            p_kilosort2('{ksort}', '{tmpdir}', '{raw}', '{geom}', '{firings}', '{arg}');
        catch
            quit(1);
        end
        quit(0);
        '''
    cmd = cmd.format(source_dir=source_dir,
                     ksort=kilosort2_path,
                     tmpdir=tmpdir,
                     raw=dataset_dir + '/raw.mda',
                     geom=dataset_dir + '/geom.csv',
                     firings=tmpdir + '/firings.mda',
                     arg=dataset_dir + '/argfile.txt')
    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_kilosort2.m',
                                  keep_temp_files=True)
    matlab_cmd.write()
    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        echo '=====================' `date` '====================='
        matlab -nosplash -nodisplay -r run_kilosort2
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_kilosort2.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_kilosort2.sh')
    shell_cmd.start()
    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('kilosort2 returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
Code Example #29
def toy_example(duration: float = 10.,
                num_channels: int = 4,
                sampling_frequency: float = 30000.,
                K: int = 10,
                dumpable: bool = False,
                dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable'
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an
        MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a NumpySortingExtractor, otherwise it's an
        NpzSortingExtractor
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
Code Example #30
 def test_empty_write(self):
     sorting_empty = se.NumpySortingExtractor()
     se.NpzSortingExtractor.write_sorting(sorting_empty, 'test_NpzSortingExtractors_empty.npz')