def main():
    """Build a synthetic ground-truth sorting and a degraded copy of it.

    For each "true" unit, random event times are drawn.  The degraded sorting
    keeps only a random subset of them (simulated false negatives) and adds
    random extra events (simulated false positives).  A few additional units
    that have no ground-truth counterpart are then added to the degraded
    sorting only.

    NOTE(review): both extractors are built but neither returned nor used
    afterwards in the visible code -- confirm whether something is missing.
    """
    samplerate = 30000
    duration_sec = 10
    true_firing_rates_hz = [1, 2, 3, 4, 5]
    approx_false_negative_rates = [0, 0.1, 0.2, 0.3, 0.4]
    approx_false_positive_rates = [0, 0.2, 0.1, 0.4, 0.3]
    extra_unit_firing_rates_hz = [0.5, 1, 1.5]

    num_timepoints = samplerate * duration_sec  # number of timepoints
    all_timepoints = np.arange(num_timepoints)

    sorting_true = se.NumpySortingExtractor()
    sorting = se.NumpySortingExtractor()

    per_unit_rates = zip(true_firing_rates_hz,
                         approx_false_negative_rates,
                         approx_false_positive_rates)
    for unit_id, (rate_hz, fn_rate, fp_rate) in enumerate(per_unit_rates, start=1):
        num_events = int(duration_sec * rate_hz)
        times0 = np.random.choice(all_timepoints, size=num_events, replace=False).astype(float)

        # Events of the true unit that the simulated sorter "detects".
        num_hits = int((1 - fn_rate) * num_events)
        hits = np.random.choice(times0, size=num_hits, replace=False)

        # Spurious events that the simulated sorter adds.
        num_extra = int(fp_rate * num_events)
        extra = np.random.choice(all_timepoints, size=num_extra, replace=False).astype(float)

        # The two arrays generally differ in length, so they must be joined,
        # not added element-wise.
        times1 = np.sort(np.concatenate((hits, extra)))

        sorting_true.add_unit(unit_id, times0)
        sorting.add_unit(unit_id, times1)

    # Units present only in the degraded sorting.
    first_extra_id = len(true_firing_rates_hz) + 1
    for unit_id, rate_hz in enumerate(extra_unit_firing_rates_hz, start=first_extra_id):
        num_events = int(duration_sec * rate_hz)
        times0 = np.random.choice(all_timepoints, size=num_events, replace=False).astype(float)
        sorting.add_unit(unit_id, times0)
def _create_example(self, seed):
    """Create seeded recording/sorting fixtures for the tests.

    Every random draw uses a brand-new ``RandomState(seed=seed)``, so each
    draw starts from the same generator state.

    :param seed: seed passed to every ``np.random.RandomState`` instance
    :return: tuple ``(RX, RX2, RX3, SX, SX2, SX3, example_info)``
    """
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    sampling_frequency = 30000

    def new_rng():
        # A fresh generator per draw, as the fixture has always done.
        return np.random.RandomState(seed=seed)

    def uniform_frames(count):
        return new_rng().uniform(0, num_frames, count)

    X = new_rng().normal(0, 1, (num_channels, num_frames))
    geom = new_rng().normal(0, 1, (num_channels, 2))
    X = (X * 100).astype(int)

    recording_kwargs = dict(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX = se.NumpyRecordingExtractor(**recording_kwargs)
    RX2 = se.NumpyRecordingExtractor(**recording_kwargs)
    RX3 = se.NumpyRecordingExtractor(**recording_kwargs)

    # First sorting: three units with sorted spike trains.
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(np.rint(uniform_frames(spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    SX.add_unit(unit_id=2, times=np.sort(uniform_frames(spike_times[1])))
    SX.add_unit(unit_id=3, times=np.sort(uniform_frames(spike_times[2])))
    SX.set_unit_property(unit_id=1, property_name='stability', value=80)
    SX.set_sampling_frequency(sampling_frequency)

    # Second sorting: unit ids 3-5, spike trains left unsorted.
    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    train2 = np.rint(uniform_frames(spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    SX2.add_unit(unit_id=4, times=uniform_frames(spike_times2[1]))
    SX2.add_unit(unit_id=5, times=uniform_frames(spike_times2[2]))
    SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
    SX2.set_unit_spike_features(unit_id=3, feature_name='widths',
                                value=np.asarray([3] * spike_times2[0]))
    RX.set_channel_locations([0, 0], channel_ids=0)

    # Properties/features shared by every unit and every channel.
    for position, unit_id in enumerate(SX2.get_unit_ids()):
        SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=position)
        SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
                                    value=np.asarray([position] * spike_times2[position]))
    for position, channel_id in enumerate(RX.get_channel_ids()):
        RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop',
                                value=position)

    # Third sorting: one unit with a full feature and an index-restricted feature.
    SX3 = se.NumpySortingExtractor()
    train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47])
    SX3.add_unit(unit_id=0, times=train3)
    features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35])
    features4 = np.asarray([0, 10, 20, 30])
    feature4_idx = np.asarray([0, 2, 4, 6])
    SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)
    SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4,
                                indexes=feature4_idx)

    example_info = {
        'channel_ids': channel_ids,
        'num_channels': num_channels,
        'num_frames': num_frames,
        'sampling_frequency': sampling_frequency,
        'unit_ids': [1, 2, 3],
        'train1': train1,
        'train2': train2,
        'train3': train3,
        'features3': features3,
        'unit_prop': 80,
        'channel_prop': (0, 0),
    }

    return (RX, RX2, RX3, SX, SX2, SX3, example_info)
def make_sorting(times1, labels1, times2, labels2):
    """Build a ground-truth and a tested sorting from times/labels sequences.

    Both extractors get a sampling frequency of 30000.

    :return: tuple ``(gt_sorting, tested_sorting)``
    """
    built = []
    for spike_times, spike_labels in ((times1, labels1), (times2, labels2)):
        extractor = se.NumpySortingExtractor()
        extractor.set_times_labels(np.array(spike_times), np.array(spike_labels))
        extractor.set_sampling_frequency(30000)
        built.append(extractor)
    gt_sorting, tested_sorting = built
    return gt_sorting, tested_sorting
def make_sorting(times1, labels1, times2, labels2, times3, labels3):
    """Build three sortings, one per (times, labels) pair.

    No sampling frequency is set on the returned extractors.

    :return: tuple ``(sorting1, sorting2, sorting3)``
    """
    def build(spike_times, spike_labels):
        extractor = se.NumpySortingExtractor()
        extractor.set_times_labels(np.array(spike_times), np.array(spike_labels))
        return extractor

    return (
        build(times1, labels1),
        build(times2, labels2),
        build(times3, labels3),
    )
def __init__(self, arg, samplerate=None):
    """Create the wrapped sorting extractor from ``arg``.

    ``arg`` is either a dict that already contains a ``'sorting_format'`` key
    (used as is, after a shallow copy) or anything that
    ``_create_object_for_arg`` can turn into such a dict.  A dict with a
    top-level ``'firings'`` key is treated as the ``'mda'`` format, with a
    default samplerate of 30000.

    Supported formats: ``mda``, ``h5_v1``, ``npy1``, ``snippets1``, ``npy2``,
    ``nwb`` and ``in_memory``.  Any other format raises ``Exception``.
    """
    super().__init__()

    if isinstance(arg, dict) and 'sorting_format' in arg:
        obj = dict(arg)
    else:
        obj = _create_object_for_arg(arg, samplerate=samplerate)
    assert obj is not None, f'Unable to create sorting from arg: {arg}'
    self._object: dict = obj

    if 'firings' in self._object:
        # Object carries 'firings' at the top level: handle it as 'mda'.
        fmt = 'mda'
        payload = {
            'firings': self._object['firings'],
            'samplerate': self._object.get('samplerate', 30000),
        }
    else:
        fmt = self._object['sorting_format']
        payload = self._object['data']

    def numpy_sorting(rate, times, labels):
        # Shared tail of the two npy-based formats.
        extractor = se.NumpySortingExtractor()
        extractor.set_sampling_frequency(rate)
        extractor.set_times_labels(times.ravel(), labels.ravel())
        return extractor

    if fmt == 'mda':
        firings_path = kp.load_file(payload['firings'])
        assert firings_path is not None, f'Unable to load firings file: {payload["firings"]}'
        self._sorting = MdaSortingExtractor(firings_file=firings_path,
                                            samplerate=payload['samplerate'])
    elif fmt == 'h5_v1':
        self._sorting = H5SortingExtractorV1(h5_path=kp.load_file(payload['h5_path']))
    elif fmt == 'npy1':
        times_npy = kp.load_npy(payload['times_npy_uri'])
        labels_npy = kp.load_npy(payload['labels_npy_uri'])
        self._sorting = numpy_sorting(payload['samplerate'], times_npy, labels_npy)
    elif fmt == 'snippets1':
        self._sorting = Snippets1SortingExtractor(snippets_h5_uri=payload['snippets_h5_uri'],
                                                  p2p=True)
    elif fmt == 'npy2':
        npz = kp.load_npy(payload['npz_uri'])
        times_npy = npz['spike_indexes']
        labels_npy = npz['spike_labels']
        self._sorting = numpy_sorting(float(npz['sampling_frequency']), times_npy, labels_npy)
    elif fmt == 'nwb':
        from .nwbextractors import NwbSortingExtractor
        self._sorting = NwbSortingExtractor(kp.load_file(payload['path']))
    elif fmt == 'in_memory':
        in_memory = get_in_memory_object(payload)
        if in_memory is None:
            raise Exception('Unable to find in-memory object for sorting')
        self._sorting = in_memory
    else:
        raise Exception(f'Unexpected sorting format: {fmt}')

    self.copy_unit_properties(sorting=self._sorting)
def _create_example(self):
    """Create random recording/sorting fixtures for the tests.

    Draws come from the global NumPy RNG, so the data differ between runs.

    :return: tuple ``(RX, RX2, RX3, SX, SX2, example_info)``
    """
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    samplerate = 30000

    X = np.random.normal(0, 1, (num_channels, num_frames))
    geom = np.random.normal(0, 1, (num_channels, 2))
    X = (X * 100).astype(int)
    RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    RX2 = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    RX3 = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)

    # NOTE(review): 'stablility' looks like a typo for 'stability'; it is kept
    # because code outside this function may read the property by this name.
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(np.rint(np.random.uniform(0, num_frames, spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    SX.add_unit(unit_id=2, times=np.sort(np.random.uniform(0, num_frames, spike_times[1])))
    SX.add_unit(unit_id=3, times=np.sort(np.random.uniform(0, num_frames, spike_times[2])))
    SX.set_unit_property(unit_id=1, property_name='stablility', value=80)
    SX.set_sampling_frequency(samplerate)

    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    # Spike counts for SX2 come from spike_times2 (index 0 previously used
    # spike_times by mistake).
    train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
    SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
    SX2.set_unit_property(unit_id=4, property_name='stablility', value=80)

    RX.set_channel_property(channel_id=0, property_name='location', value=(0, 0))

    example_info = dict(
        channel_ids=channel_ids,
        num_channels=num_channels,
        num_frames=num_frames,
        samplerate=samplerate,
        unit_ids=[1, 2, 3],
        train1=train1,
        unit_prop=80,
        channel_prop=(0, 0),
    )

    return (RX, RX2, RX3, SX, SX2, example_info)
def mountainsort4(*, recording, detect_sign, clip_size=50, adjacency_radius=-1,
                  detect_threshold=3, detect_interval=10, num_workers=None):
    """Run MountainSort4 on ``recording`` and return the result as a sorting.

    :param recording: recording passed to ``MountainSort4.setRecording``
    :param detect_sign: forwarded to ``setSortingOpts``
    :param clip_size: forwarded to ``setSortingOpts``
    :param adjacency_radius: forwarded to ``setSortingOpts``
    :param detect_threshold: forwarded to ``setSortingOpts``
    :param detect_interval: forwarded to ``setSortingOpts``
    :param num_workers: number of workers; defaults to about half the CPU count
    :return: ``se.NumpySortingExtractor`` holding the event times and labels
    """
    if num_workers is None:
        num_workers = int((multiprocessing.cpu_count() + 1) / 2)

    print('Using {} workers.'.format(num_workers))

    MS4 = MountainSort4()
    MS4.setRecording(recording)
    geom = _get_geom_from_recording(recording)
    MS4.setGeom(geom)
    MS4.setSortingOpts(
        clip_size=clip_size,
        adjacency_radius=adjacency_radius,
        detect_sign=detect_sign,
        detect_interval=detect_interval,
        detect_threshold=detect_threshold
    )

    tmpdir = tempfile.mkdtemp()
    sort_succeeded = False
    try:
        # Everything that runs while the temporary directory exists is inside
        # the try, so the directory is removed even if setup itself fails.
        MS4.setNumWorkers(num_workers)
        print('Using tmpdir: ' + tmpdir)
        MS4.setTemporaryDirectory(tmpdir)
        MS4.sort()
        sort_succeeded = True
    finally:
        # The log text differs between the success and failure paths; both
        # strings are kept as they were.
        if sort_succeeded:
            print('Cleaning tmpdir::::: ' + tmpdir)
        else:
            print('Cleaning tmpdir:: ' + tmpdir)
        shutil.rmtree(tmpdir)

    times, labels, channels = MS4.eventTimesLabelsChannels()
    output = se.NumpySortingExtractor()
    output.set_times_labels(times=times, labels=labels)
    return output
def setUp(self):
    """Create a seeded recording (with TTL frames) and a three-unit sorting.

    Every random draw uses its own ``RandomState(seed=seed)`` so the fixture
    is identical on every run; this includes the TTL frames, which used to be
    drawn from the unseeded global RNG.
    """
    M = 4
    N = 10000
    N_ttl = 50
    seed = 0
    sampling_frequency = 30000

    X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
    geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2))
    self._X = X
    self._geom = geom
    self._sampling_frequency = sampling_frequency
    self.RX = se.NumpyRecordingExtractor(
        timeseries=X, sampling_frequency=sampling_frequency, geom=geom)

    # N_ttl distinct frame indices, in increasing order.
    self._ttl_frames = np.sort(np.random.RandomState(seed=seed).permutation(N)[:N_ttl])
    self.RX.set_ttls(self._ttl_frames)

    self.SX = se.NumpySortingExtractor()
    L = 200
    self._train1 = np.rint(
        np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int)
    self.SX.add_unit(unit_id=1, times=self._train1)
    self.SX.add_unit(unit_id=2, times=np.random.RandomState(seed=seed).uniform(0, N, L))
    self.SX.add_unit(unit_id=3, times=np.random.RandomState(seed=seed).uniform(0, N, L))
def toy_example(duration=10, num_channels=4, samplerate=30000.0, K=10, seed=None):
    """Synthesize a toy recording together with the sorting that generated it.

    :param duration: forwarded to the firing and timeseries synthesizers
    :param num_channels: number of channels (``M``) of the random waveforms
    :param samplerate: sampling rate used for synthesis and for both extractors
    :param K: number of units
    :param seed: forwarded to the waveform and firing synthesizers
    :return: tuple ``(RX, SX)`` of recording and sorting extractors
    """
    upsample_factor = 13

    templates, probe_geom = synthesize_random_waveforms(
        K=K,
        M=num_channels,
        average_peak_amplitude=-100,
        upsamplefac=upsample_factor,
        seed=seed,
    )
    event_times, event_labels = synthesize_random_firings(
        K=K,
        duration=duration,
        samplerate=samplerate,
        seed=seed,
    )

    SX = se.NumpySortingExtractor()
    SX.set_times_labels(event_times, event_labels.astype(np.int64))

    traces = synthesize_timeseries(
        sorting=SX,
        waveforms=templates,
        noise_level=10,
        samplerate=samplerate,
        duration=duration,
        waveform_upsamplefac=upsample_factor,
    )
    # The sampling frequency is attached only after the timeseries is built.
    SX.set_sampling_frequency(samplerate)

    RX = se.NumpyRecordingExtractor(timeseries=traces, samplerate=samplerate, geom=probe_geom)
    return (RX, SX)
def get_unmatched_sorting(sx1, sx2, ids1, ids2, delta=100):
    """Return the spikes of ``sx1`` that have no match in ``sx2``.

    Units are paired by position: ``ids1[i]`` in ``sx1`` is compared with
    ``ids2[i]`` in ``sx2``.  The unmatched events of each pair are stored in
    the result under the id from ``ids1``.

    :param sx1: sorting whose unmatched spikes are collected
    :param sx2: sorting to compare against
    :param ids1: unit ids in ``sx1``
    :param ids2: unit ids in ``sx2``; needs at least ``len(ids1)`` entries
    :param delta: forwarded to ``get_unmatched_times`` (default 100, the
        previously hard-coded value)
    :return: ``se.NumpySortingExtractor`` with one unit per entry of ``ids1``
    :raises IndexError: if ``ids2`` is shorter than ``ids1``
    """
    # zip() would silently truncate; keep the failure explicit instead.
    if len(ids2) < len(ids1):
        raise IndexError(
            'ids2 has {} entries but ids1 has {}'.format(len(ids2), len(ids1)))

    ret = se.NumpySortingExtractor()
    for id1, id2 in zip(ids1, ids2):
        train1 = sx1.get_unit_spike_train(unit_id=id1)
        train2 = sx2.get_unit_spike_train(unit_id=id2)
        ret.add_unit(id1, get_unmatched_times(train1, train2, delta=delta))
    return ret
def get_unmatched_sorting(sx1, sx2, ids1, ids2):
    """Collect the spikes in the first sorting that are not matched in the second.

    Units are paired by position (``ids1[i]`` with ``ids2[i]``); each result
    unit is stored under its id from ``ids1``.  Matching uses a fixed
    ``delta`` of 100.
    """
    unmatched = se.NumpySortingExtractor()
    for position, first_id in enumerate(ids1):
        second_id = ids2[position]
        first_train = sx1.get_unit_spike_train(unit_id=first_id)
        second_train = sx2.get_unit_spike_train(unit_id=second_id)
        leftover = get_unmatched_times(first_train, second_train, delta=100)
        unmatched.addUnit(first_id, leftover)
    return unmatched
def __init__(self, arg, samplerate=None):
    """Create the wrapped sorting extractor from ``arg``.

    ``arg`` is either a dict that already contains a ``'sorting_format'`` key
    (shallow-copied and used directly) or a value that
    ``_create_object_for_arg`` converts into such a dict.

    Supported formats: ``mda``, ``h5_v1``, ``npy1`` and ``npy2``.  Any other
    format raises ``Exception``.
    """
    super().__init__()

    already_serialized = isinstance(arg, dict) and 'sorting_format' in arg
    obj = dict(arg) if already_serialized else _create_object_for_arg(arg, samplerate=samplerate)
    assert obj is not None, f'Unable to create sorting from arg: {arg}'
    self._object: dict = obj

    fmt = self._object['sorting_format']
    payload: dict = self._object['data']

    if fmt == 'mda':
        firings_path = kp.load_file(payload['firings'])
        assert firings_path is not None, f'Unable to load firings file: {payload["firings"]}'
        self._sorting = MdaSortingExtractor(
            firings_file=firings_path, samplerate=payload['samplerate'])
    elif fmt == 'h5_v1':
        self._sorting = H5SortingExtractorV1(h5_path=kp.load_file(payload['h5_path']))
    elif fmt in ('npy1', 'npy2'):
        # Both numpy formats end up as times + labels + a sampling rate.
        if fmt == 'npy1':
            times_npy = kp.load_npy(payload['times_npy_uri'])
            labels_npy = kp.load_npy(payload['labels_npy_uri'])
            rate = payload['samplerate']
        else:
            npz = kp.load_npy(payload['npz_uri'])
            times_npy = npz['spike_indexes']
            labels_npy = npz['spike_labels']
            rate = float(npz['sampling_frequency'])
        numpy_sorting = se.NumpySortingExtractor()
        numpy_sorting.set_sampling_frequency(rate)
        numpy_sorting.set_times_labels(times_npy.ravel(), labels_npy.ravel())
        self._sorting = numpy_sorting
    else:
        raise Exception(f'Unexpected sorting format: {fmt}')

    self.copy_unit_properties(sorting=self._sorting)
def test_npz_extractor(self):
    """Round-trip ``self.SX`` through the NPZ format and write an empty sorting."""
    npz_path = self.test_dir + '/sorting.npz'
    se.NpzSortingExtractor.write_sorting(self.SX, npz_path)
    reloaded = se.NpzSortingExtractor(npz_path)

    # empty write
    se.NpzSortingExtractor.write_sorting(
        se.NumpySortingExtractor(),
        self.test_dir + '/sorting_empty.npz',
    )

    # The reloaded sorting must be well-typed, equal to the source, and dumpable.
    check_sorting_return_types(reloaded)
    check_sortings_equal(self.SX, reloaded)
    check_dumping(reloaded)
def get_result_from_folder(output_folder):
    """Load ``firings.mda`` from ``output_folder`` as a sorting extractor.

    Row 1 of the firings array is used as spike times and row 2 as labels.

    :raises AssertionError: if ``firings.mda`` is not in ``output_folder``
    """
    # overwrite the SorterBase.get_result
    from mountainlab_pytools import mdaio

    firings_file = Path(output_folder).joinpath('firings.mda')
    assert firings_file.exists(), f'Result file does not exist: {firings_file}'

    firings = mdaio.readmda(str(firings_file))
    result = se.NumpySortingExtractor()
    result.set_times_labels(firings[1, :], firings[2, :])
    return result
def toy_example1(duration=10, num_channels=4, samplerate=30000, K=10, firing_rates=None, noise_level=10):
    """Synthesize a toy recording and the sorting used to generate it.

    :param duration: forwarded to the firing and timeseries synthesizers
    :param num_channels: number of channels (``M``) of the random waveforms
    :param samplerate: sampling rate used for synthesis and for the recording
    :param K: number of units
    :param firing_rates: forwarded to ``synthesize_random_firings``
    :param noise_level: forwarded to ``synthesize_timeseries``
    :return: tuple ``(IX, OX)`` of recording and sorting extractors
    """
    upsample_factor = 13

    templates, probe_geom = synthesize_random_waveforms(
        K=K,
        M=num_channels,
        average_peak_amplitude=-100,
        upsamplefac=upsample_factor,
    )
    event_times, event_labels = synthesize_random_firings(
        K=K,
        duration=duration,
        samplerate=samplerate,
        firing_rates=firing_rates,
    )

    OX = se.NumpySortingExtractor()
    OX.set_times_labels(event_times, event_labels.astype(np.int64))

    traces = synthesize_timeseries(
        sorting=OX,
        waveforms=templates,
        noise_level=noise_level,
        samplerate=samplerate,
        duration=duration,
        waveform_upsamplefac=upsample_factor,
    )
    IX = se.NumpyRecordingExtractor(timeseries=traces, samplerate=samplerate, geom=probe_geom)
    return (IX, OX)
def setup_study():
    """Assemble (recording, ground-truth sorting) pairs and set up the study.

    For every recording name, the MEA raw file is located in the recording
    folder, its byte offset is read from the companion ``.txt`` header, and a
    single-unit ground-truth sorting is built from the juxta peak indexes.
    The pairs are then passed to ``GroundTruthStudy.setup``.

    Relies on the module-level names ``recording_folder``,
    ``ground_truth_folder``, ``basedir`` and ``study_folder``.

    :raises FileNotFoundError: if a recording folder holds no MEA ``.raw`` file
    """
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:
        # find raw file (reset per recording so a missing file cannot reuse
        # the path found for the previous recording)
        dirname = recording_folder + rec_name + '/'
        mea_filename = None
        for fname in os.listdir(dirname):
            if fname.endswith('.raw') and not fname.endswith('juxta.raw'):
                mea_filename = dirname + fname
        if mea_filename is None:
            raise FileNotFoundError('No MEA .raw file found in ' + dirname)

        # raw files have an internal offset that depend on the channel count
        # a simple built header can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as header_file:
            offset = int(re.findall(r'padding = (\d+)', header_file.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename, 20000., 256, 'uint16',
                                          offset=offset, frames_first=True)

        # this reduce channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # gt sorting: every juxta peak is assigned to unit 0
        gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    GroundTruthStudy.setup(study_folder, gt_dict)
def setUp(self):
    """Create a random 4-channel recording and a three-unit sorting fixture.

    Draws come from the global NumPy RNG, in the order: timeseries, geometry,
    then the spike trains of units 1, 2 and 3.
    """
    num_channels, num_frames = 4, 10000
    samplerate = 30000

    self._X = np.random.normal(0, 1, (num_channels, num_frames))
    self._geom = np.random.normal(0, 1, (num_channels, 2))
    self._samplerate = samplerate
    self.RX = se.NumpyRecordingExtractor(timeseries=self._X, samplerate=samplerate,
                                         geom=self._geom)

    spikes_per_unit = 200
    self.SX = se.NumpySortingExtractor()
    # Unit 1 keeps a rounded integer train; units 2 and 3 get raw float draws.
    self._train1 = np.rint(np.random.uniform(0, num_frames, spikes_per_unit)).astype(int)
    self.SX.add_unit(unit_id=1, times=self._train1)
    for unit_id in (2, 3):
        self.SX.add_unit(unit_id=unit_id,
                         times=np.random.uniform(0, num_frames, spikes_per_unit))
def toy_example(duration=10, num_channels=4, sampling_frequency=30000.0, K=10, dumpable=False, dump_folder=None,
                seed=None):
    """Synthesize a toy recording and its generating sorting.

    :param duration: forwarded to the firing and timeseries synthesizers
    :param num_channels: number of channels (``M``) of the random waveforms
    :param sampling_frequency: sampling frequency for synthesis and extractors
    :param K: number of units
    :param dumpable: if True, both extractors are written to ``dump_folder``
        and re-opened from disk (MDA recording, NPZ sorting)
    :param dump_folder: target folder when ``dumpable``; ``'toy_example'`` if None
    :param seed: forwarded to all three synthesizers
    :return: tuple ``(RX, SX)``
    """
    upsample_factor = 13

    templates, probe_geom = synthesize_random_waveforms(
        K=K, M=num_channels, average_peak_amplitude=-100,
        upsamplefac=upsample_factor, seed=seed)
    event_times, event_labels = synthesize_random_firings(
        K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)

    SX = se.NumpySortingExtractor()
    SX.set_times_labels(event_times, event_labels.astype(np.int64))

    traces = synthesize_timeseries(
        sorting=SX, waveforms=templates, noise_level=10,
        sampling_frequency=sampling_frequency, duration=duration,
        waveform_upsamplefac=upsample_factor, seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=sampling_frequency,
                                    geom=probe_geom)
    RX.is_filtered = True

    if not dumpable:
        return RX, SX

    # Swap the in-memory extractors for file-backed ones.
    dump_folder = Path('toy_example' if dump_folder is None else dump_folder)
    se.MdaRecordingExtractor.write_recording(RX, dump_folder)
    RX = se.MdaRecordingExtractor(dump_folder)
    sorting_file = dump_folder / 'sorting.npz'
    se.NpzSortingExtractor.write_sorting(SX, sorting_file)
    SX = se.NpzSortingExtractor(sorting_file)

    return RX, SX
def gen_synth_datasets(datasets, *, outdir, num_channels=4, upsamplefac=13, samplerate=30000,
                       average_peak_amplitude=-100):
    """Generate synthetic datasets and write them below ``outdir``.

    For each entry of ``datasets`` a recording is written to
    ``<outdir>/<name>`` and the true firings to
    ``<outdir>/<name>/firings_true.mda``.

    :param datasets: iterable of dicts with keys ``'name'``, ``'K'``,
        ``'duration'`` and ``'noise_level'``
    :param outdir: output directory; created (with parents) if missing
    :param num_channels: number of channels (``M``) of the random waveforms
    :param upsamplefac: waveform upsampling factor
    :param samplerate: sampling rate used for synthesis and for the recording
    :param average_peak_amplitude: forwarded to ``synthesize_random_waveforms``
    """
    # exist_ok avoids the check-then-create race and supports nested paths.
    os.makedirs(outdir, exist_ok=True)

    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        K = ds['K']
        duration = ds['duration']
        noise_level = ds['noise_level']

        waveforms, geom = synthesize_random_waveforms(
            K=K, M=num_channels, average_peak_amplitude=average_peak_amplitude,
            upsamplefac=upsamplefac)
        times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate)
        labels = labels.astype(np.int64)

        OX = si.NumpySortingExtractor()
        OX.setTimesLabels(times, labels)
        X = synthesize_timeseries(output_extractor=OX, waveforms=waveforms,
                                  noise_level=noise_level, samplerate=samplerate,
                                  duration=duration, waveform_upsamplefac=upsamplefac)
        IX = si.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)

        si.MdaRecordingExtractor.writeRecording(IX, outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(
            OX, outdir + '/{}/firings_true.mda'.format(ds_name))
        # NOTE(review): the original carried a "write json with two fields"
        # comment here with no code following it -- confirm whether a step
        # is missing.
    print('Done.')
def get_sorting_extractor(self, key, sort_interval):
    # TODO: replace with spikeinterface call if possible
    """Generates a numpy sorting extractor given a key that retrieves a SpikeSorting and a specified sort interval

    Only units whose stored ``sort_interval`` equals the requested one
    (compared element-wise after flattening both) contribute spikes; the
    unit's dataframe index is used as its label.

    :param key: key for a single SpikeSorting
    :type key: dict
    :param sort_interval: [start_time, end_time]
    :type sort_interval: numpy array
    :return: a spikeextractors sorting extractor with the sorting information
    """
    # get the units object from the NWB file that the data are stored in.
    units = (SpikeSorting & key).fetch_nwb()[0]['units'].to_dataframe()
    requested_interval = np.ravel(sort_interval)

    unit_timestamps = []
    unit_labels = []
    # TODO: do something more efficient here; note that searching for matching
    # sort_intervals within pandas doesn't seem to work
    for index, unit in units.iterrows():
        # array_equal returns False on a shape mismatch instead of failing.
        if np.array_equal(np.ravel(unit['sort_interval']), requested_interval):
            unit_timestamps.extend(unit['spike_times'])
            unit_labels.extend([index] * len(unit['spike_times']))

    output = se.NumpySortingExtractor()
    output.set_times_labels(times=np.asarray(unit_timestamps), labels=np.asarray(unit_labels))
    return output
def detect_spikes(recording, channel_ids=None, detect_threshold=5, n_pad_ms=2, upsample=1, detect_sign=-1,
                  min_diff_samples=5, parallel=False, n_jobs=-1):
    '''
    Detects spikes per channel.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    channel_ids: list or None
        List of channels to perform detection. If None all channels are used
    detect_threshold: float
        Threshold in MAD to detect peaks
    n_pad_ms: float
        Time in ms to find absolute peak around detected peak
    upsample: int
        The detected waveforms are upsampled 'upsample' times (default=1)
    detect_sign: int
        Sign of the detection: -1 (negative), 1 (positive), 0 (both)
    min_diff_samples: int
        Minimum interval to skip consecutive spikes (default=5)
    parallel: bool
        If True, each channel is run in parallel
    n_jobs: int
        Number of jobs when parallel

    Returns
    -------
    sorting_detected: SortingExtractor
        The sorting extractor object with the detected spikes. Unit ids are the same as channel ids and units have the
        'channel' property to specify which channel they correspond to

    Raises
    ------
    AssertionError
        If 'channel_ids' contains ids that are not in the recording
    '''
    spike_times = []
    labels = []
    n_pad_samples = int(n_pad_ms * recording.get_sampling_frequency() / 1000)

    # Query the recording once rather than once per requested channel.
    recording_channel_ids = recording.get_channel_ids()
    if channel_ids is None:
        channel_ids = recording_channel_ids
    else:
        assert np.all([ch in recording_channel_ids for ch in channel_ids]), \
            "Not all 'channel_ids' are in the recording."

    if parallel:
        output = Parallel(n_jobs=n_jobs)(
            delayed(_detect_and_align_peaks_single_channel)(
                recording, ch, detect_threshold, detect_sign, n_pad_samples, upsample, min_diff_samples)
            for ch in channel_ids)
        for peak_times, label in output:
            spike_times.append(peak_times)
            labels.append(label)
    else:
        for ch in channel_ids:
            peak_times, label = _detect_and_align_peaks_single_channel(
                recording, ch, detect_threshold, detect_sign, n_pad_samples, upsample, min_diff_samples)
            spike_times.append(peak_times)
            labels.append(label)

    # create sorting extractor from the flattened per-channel results
    sorting = se.NumpySortingExtractor()
    labels_flat = np.array(list(itertools.chain.from_iterable(labels)))
    times_flat = np.array(list(itertools.chain.from_iterable(spike_times)))
    sorting.set_times_labels(times=times_flat, labels=labels_flat)
    for u in sorting.get_unit_ids():
        sorting.set_unit_property(u, 'channel', u)

    return sorting
def waveclus_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        **kwargs):
    """Run waveclus through MATLAB on ``recording`` and return the sorting.

    The recording is written as an MDA dataset under ``tmpdir``, a MATLAB
    script and a bash launcher are generated there, and the resulting
    ``firings.mda`` is read back (row 1 as times, row 2 as labels).

    waveclus is taken from ``$WAVECLUS_PATH_DEV`` when set, otherwise it is
    auto-installed via ``install_waveclus``.

    :param recording: recording to sort
    :param tmpdir: working directory for the dataset, scripts and results
    :param params: forwarded to ``SFMdaRecordingExtractor.write_recording``
    :param kwargs: accepted but not used by this helper
    :return: ``se.NumpySortingExtractor`` with the sorted events
    :raises Exception: if installation fails, MATLAB exits non-zero, or the
        result file is missing
    """
    waveclus_path = os.environ.get('WAVECLUS_PATH_DEV', None)
    if waveclus_path:
        print('Using waveclus from WAVECLUS_PATH_DEV directory: {}'.format(
            waveclus_path))
    else:
        try:
            print('Auto-installing waveclus.')
            waveclus_path = install_waveclus(
                repo='https://github.com/csn-le/wave_clus.git',
                commit='248d15c7eaa2b45b15e4488dfb9b09bfe39f5341')
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C still interrupts.
            traceback.print_exc()
            raise Exception(
                'Problem installing waveclus. You can set the WAVECLUS_PATH_DEV to force to use a particular path.'
            ) from err
    print('Using waveclus from: {}'.format(waveclus_path))

    dataset_dir = os.path.join(tmpdir, 'waveclus_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            params=params,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running waveclus in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath(genpath('{waveclus_path}'), '{source_path}', '{source_path}/mdaio');
        try
            p_waveclus('{tmpdir}', '{dataset_dir}/raw.mda', '{tmpdir}/firings.mda', {samplerate});
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(waveclus_path=waveclus_path,
                     tmpdir=tmpdir,
                     dataset_dir=dataset_dir,
                     source_path=source_path,
                     samplerate=samplerate)

    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_waveclus.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_waveclus
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_waveclus.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_waveclus.sh')

    start_time = time.time()
    shell_cmd.start()
    retcode = shell_cmd.wait()
    # Elapsed time is "now - start"; the reverse printed a negative runtime.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - start_time))

    if retcode != 0:
        raise Exception('waveclus returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def ironclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        ironclust_path,
        **kwargs):
    """Run IronClust through MATLAB on ``recording`` and return the sorting.

    The recording is written as an MDA dataset under ``tmpdir`` along with an
    ``argfile.txt`` holding ``kwargs``, the samplerate and (if present in
    ``params``) the scale factor.  A MATLAB script and a bash launcher are
    generated, executed, and the resulting ``firings.mda`` is read back
    (row 1 as times, row 2 as labels).

    :raises Exception: if MATLAB exits non-zero or the result file is missing
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_dir = tmpdir + '/ironclust_dataset'

    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(
        recording=recording, save_path=dataset_dir, params=params, _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    header = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels, num_timepoints = header.dims[0], header.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(
        num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt...')
    # One "key=value" line per sorter argument, then samplerate and scale factor.
    arg_lines = ['{}={}\n'.format(name, value) for name, value in kwargs.items()]
    arg_lines.append('samplerate={}\n'.format(samplerate))
    if 'scale_factor' in params:
        arg_lines.append('scale_factor={}\n'.format(params["scale_factor"]))
    _write_text_file(dataset_dir + '/argfile.txt', ''.join(arg_lines))

    # new method
    print('Running ironclust in {tmpdir}...'.format(tmpdir=tmpdir))
    matlab_source = '''
        addpath('{source_dir}');
        addpath('{ironclust_path}', '{ironclust_path}/matlab', '{ironclust_path}/matlab/mdaio');
        try
            p_ironclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '', '', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''.format(ironclust_path=ironclust_path, tmpdir=tmpdir,
               dataset_dir=dataset_dir, source_dir=source_dir)

    matlab_script = mlpr.ShellScript(matlab_source, script_path=tmpdir + '/run_ironclust.m',
                                     keep_temp_files=True)
    matlab_script.write()

    launcher_path = tmpdir + '/run_ironclust.sh'
    launcher_source = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_ironclust
    '''.format(tmpdir=tmpdir)
    launcher = mlpr.ShellScript(launcher_source, script_path=launcher_path,
                                keep_temp_files=True)
    launcher.write(launcher_path)
    launcher.start()
    if launcher.wait() != 0:
        raise Exception('ironclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    result = se.NumpySortingExtractor()
    result.set_times_labels(firings[1, :], firings[2, :])
    return result
def jrclust_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        params=dict(),
        **kwargs):
    """Run JRCLUST through MATLAB on ``recording`` and return the sorting.

    The recording is written as an MDA dataset under ``tmpdir`` along with an
    ``argfile.txt`` holding ``kwargs``, ``bitScaling`` (when ``params`` has a
    ``'scale_factor'``) and ``sampleRate``.  A MATLAB script and a bash
    launcher are generated, executed, and the resulting ``firings.mda`` is
    read back (row 1 as times, row 2 as labels).

    JRCLUST is taken from ``$JRCLUST_PATH_DEV`` when set, otherwise it is
    auto-installed via ``install_jrclust``.

    :param recording: recording to sort
    :param tmpdir: working directory for the dataset, scripts and results
    :param params: forwarded to ``SFMdaRecordingExtractor.write_recording``
    :param kwargs: written verbatim as ``key=value`` lines to ``argfile.txt``
    :return: ``se.NumpySortingExtractor`` with the sorted events
    :raises Exception: if installation fails, MATLAB exits non-zero, or the
        result file is missing
    """
    jrclust_path = os.environ.get('JRCLUST_PATH_DEV', None)
    if jrclust_path:
        print('Using jrclust from JRCLUST_PATH_DEV directory: {}'.format(
            jrclust_path))
    else:
        try:
            print('Auto-installing jrclust.')
            jrclust_path = install_jrclust(
                repo='https://github.com/JaneliaSciComp/JRCLUST.git',
                commit='3d2e75c0041dca2a9f273598750c6a14dbc4c1b8')
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C still interrupts.
            traceback.print_exc()
            raise Exception(
                'Problem installing jrclust. You can set the JRCLUST_PATH_DEV to force to use a particular path.'
            ) from err
    print('Using jrclust from: {}'.format(jrclust_path))

    dataset_dir = os.path.join(tmpdir, 'jrclust_dataset')
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            params=params,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    raw_mda = os.path.join(dataset_dir, 'raw.mda')
    HH = mdaio.readmda_header(raw_mda)
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt...')
    txt = ''
    for key0, val0 in kwargs.items():
        txt += '{}={}\n'.format(key0, val0)
    if 'scale_factor' in params:
        txt += 'bitScaling={}\n'.format(params["scale_factor"])
    txt += 'sampleRate={}\n'.format(samplerate)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    # new method
    source_path = os.path.dirname(os.path.realpath(__file__))
    print('Running jrclust in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{jrclust_path}', '{source_path}', '{source_path}/mdaio');
        try
            p_jrclust('{tmpdir}', '{dataset_dir}/raw.mda', '{dataset_dir}/geom.csv', '{tmpdir}/firings.mda', '{dataset_dir}/argfile.txt');
        catch
            fprintf('----------------------------------------');
            fprintf(lasterr());
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(jrclust_path=jrclust_path,
                     tmpdir=tmpdir,
                     dataset_dir=dataset_dir,
                     source_path=source_path)

    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_jrclust.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        matlab -nosplash -nodisplay -r run_jrclust
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_jrclust.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_jrclust.sh')

    start_time = time.time()
    shell_cmd.start()
    retcode = shell_cmd.wait()
    # Elapsed time is "now - start"; the reverse printed a negative runtime.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(time.time() - start_time))

    if retcode != 0:
        raise Exception('jrclust returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def _create_example(self):
    """Build random recording and sorting extractors used as test fixtures.

    Returns a tuple ``(RX, RX2, RX3, SX, SX2, example_info)``: three
    recording extractors created from the same traces and geometry, two
    sorting extractors with three units each, and a dict holding the values
    used to build them.
    """
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    sampling_frequency = 30000

    # Gaussian noise scaled by 100 and truncated to integers, plus random
    # 2D channel positions.
    traces = np.random.normal(0, 1, (num_channels, num_frames))
    geom = np.random.normal(0, 1, (num_channels, 2))
    traces = (traces * 100).astype(int)

    recording_kwargs = dict(timeseries=traces,
                            sampling_frequency=sampling_frequency,
                            geom=geom)
    RX = se.NumpyRecordingExtractor(**recording_kwargs)
    RX2 = se.NumpyRecordingExtractor(**recording_kwargs)
    RX3 = se.NumpyRecordingExtractor(**recording_kwargs)

    # First sorting: units 1-3 with sorted spike trains; only unit 1 is
    # rounded to integer frames.
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(
        np.rint(np.random.uniform(0, num_frames,
                                  spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    for uid, n_spikes in zip((2, 3), spike_times[1:]):
        SX.add_unit(unit_id=uid,
                    times=np.sort(
                        np.random.uniform(0, num_frames, n_spikes)))
    SX.set_unit_property(unit_id=1, property_name='stability', value=80)
    SX.set_sampling_frequency(sampling_frequency)

    # Second sorting: units 3-5 with unsorted spike trains.
    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    train2 = np.rint(np.random.uniform(0, num_frames,
                                       spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    for uid, n_spikes in zip((4, 5), spike_times2[1:]):
        SX2.add_unit(unit_id=uid,
                     times=np.random.uniform(0, num_frames, n_spikes))
    SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
    SX2.set_unit_spike_features(unit_id=3,
                                feature_name='widths',
                                value=np.asarray([3] * spike_times2[0]))

    RX.set_channel_property(channel_id=0,
                            property_name='location',
                            value=(0, 0))

    # Properties/features shared by every unit and channel; the value is the
    # enumeration index.
    for idx, uid in enumerate(SX2.get_unit_ids()):
        SX2.set_unit_property(unit_id=uid,
                              property_name='shared_unit_prop',
                              value=idx)
        SX2.set_unit_spike_features(
            unit_id=uid,
            feature_name='shared_unit_feature',
            value=np.asarray([idx] * spike_times2[idx]))
    for idx, chan in enumerate(RX.get_channel_ids()):
        RX.set_channel_property(channel_id=chan,
                                property_name='shared_channel_prop',
                                value=idx)

    example_info = {
        'channel_ids': channel_ids,
        'num_channels': num_channels,
        'num_frames': num_frames,
        'sampling_frequency': sampling_frequency,
        'unit_ids': [1, 2, 3],
        'train1': train1,
        'unit_prop': 80,
        'channel_prop': (0, 0),
    }
    return (RX, RX2, RX3, SX, SX2, example_info)
def ironclust(*,
              recording,  # Recording object
              tmpdir,  # Temporary working directory
              detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
              adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
              detect_threshold=5,  # Threshold for detection
              merge_thresh=.98,  # Cluster merging threshold 0..1
              freq_min=300,  # Lower frequency limit for band-pass filter
              freq_max=6000,  # Upper frequency limit for band-pass filter
              pc_per_chan=3,  # Number of pc per channel
              prm_template_name,  # Name of the template file
              ironclust_src=None
              ):
    """Run IronClust through MATLAB and return the resulting sorting.

    The recording is written to ``<tmpdir>/ironclust_dataset``, the sorter
    arguments are stored in ``argfile.txt`` there, and ``p_ironclust`` is
    invoked via a ``matlab -r`` command line. The IronClust source location
    comes from ``ironclust_src`` or, when that is None, from the
    ``IRONCLUST_SRC`` environment variable.

    Returns a ``si.NumpySortingExtractor`` built from rows 1 (times) and
    2 (labels) of ``<tmpdir>/firings.mda``.

    Raises ``Exception`` when no source location is available, when the
    command returns a non-zero exit code, or when the result file is missing.
    """
    if ironclust_src is None:
        ironclust_src = os.getenv('IRONCLUST_SRC', None)
    if not ironclust_src:
        raise Exception('You must either set the IRONCLUST_SRC environment variable, or pass the ironclust_src parameter')

    # Not used further below; retained from the original implementation.
    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/ironclust_dataset'
    raw_path = dataset_dir + '/raw.mda'
    geom_path = dataset_dir + '/geom.csv'
    argfile_path = dataset_dir + '/argfile.txt'
    result_fname = tmpdir + '/firings.mda'

    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    si.MdaRecordingExtractor.writeRecording(recording=recording, save_path=dataset_dir)

    samplerate = recording.getSamplingFrequency()

    print('Reading timeseries header...')
    header = mdaio.readmda_header(raw_path)
    num_channels, num_timepoints = header.dims[0], header.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.format(num_channels, num_timepoints, duration_minutes))

    print('Creating .params file...')
    # One "key=value" line per argument, in this fixed order.
    arg_items = [
        ('samplerate', samplerate),
        ('detect_sign', detect_sign),
        ('adjacency_radius', adjacency_radius),
        ('detect_threshold', detect_threshold),
        ('merge_thresh', merge_thresh),
        ('freq_min', freq_min),
        ('freq_max', freq_max),
        ('pc_per_chan', pc_per_chan),
        ('prm_template_name', prm_template_name),
    ]
    txt = ''.join('{}={}\n'.format(key, val) for key, val in arg_items)
    _write_text_file(argfile_path, txt)

    print('Running IronClust...')
    cmd_path = "addpath('{}', '{}/matlab', '{}/mdaio');".format(*([ironclust_src] * 3))
    # "p_ironclust('$(tempdir)','$timeseries$','$geom$','$prm$','$firings_true$','$firings_out$','$(argfile)');"
    cmd_call = "p_ironclust('{}', '{}', '{}', '', '', '{}', '{}');".format(
        tmpdir, raw_path, geom_path, result_fname, argfile_path)
    cmd = 'matlab -nosplash -nodisplay -r "{} {} quit;"'.format(cmd_path, cmd_call)
    print(cmd)
    retcode = _run_command_and_print_output(cmd)

    if retcode != 0:
        raise Exception('IronClust returned a non-zero exit code')

    # parse output
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = si.NumpySortingExtractor()
    sorting.setTimesLabels(firings[1, :], firings[2, :])
    return sorting
def create_signal_with_known_waveforms(n_channels=4, n_waveforms=2, n_wf_samples=100, duration=5, fs=30000):
    '''
    Build a synthetic recording/sorting pair whose waveforms are known.

    Random waveforms are turned into multi-channel templates, and a slightly
    rescaled copy of a template is written into an all-zero timeseries every
    10 * n_wf_samples samples, cycling through the units.

    Returns (rec, sort, waveforms, templates, max_chans, amplitudes), where
    waveforms and amplitudes are per-unit lists of arrays.
    '''
    a_min = [-200, -50]
    a_max = [10, 50]

    # gen waveforms
    wfs = []
    for _ in range(n_waveforms):
        trough = np.random.randint(a_min[0], a_min[1])
        peak = np.random.randint(a_max[0], a_max[1])
        wfs.append(create_wf(trough, peak, n_wf_samples))

    # gen templates: retry until the generator reports success
    templates = []
    max_chans = []
    for wf in wfs:
        while True:
            template, chan_amps, found = generate_template_with_random_amps(
                n_channels, wf)
            if found:
                break
        templates.append(template)
        max_chans.append(np.argmax(chan_amps))
    templates = np.array(templates)

    n_samples = int(fs * duration)

    # gen spiketrains: regularly spaced events, labels assigned round-robin
    interval = 10 * n_wf_samples
    times = np.arange(interval, duration * fs - interval,
                      interval).astype(int)
    labels = np.zeros(len(times)).astype(int)
    n_units = len(wfs)
    for unit in range(n_units):
        labels[unit::n_units] = unit

    timeseries = np.zeros((n_channels, n_samples))
    half_width = n_wf_samples // 2
    waveforms = []
    amplitudes = []
    for unit, tem in enumerate(templates):
        unit_wavs = []
        unit_amps = []
        for t in times[labels == unit]:
            # Scale factor drawn around 1 with a 0.01 standard deviation.
            scaled = (np.random.randn() * 0.01 + 1) * tem
            timeseries[:, t - half_width:t + half_width] = scaled
            unit_wavs.append(scaled)
            unit_amps.append(np.min(scaled))
        waveforms.append(np.array(unit_wavs))
        amplitudes.append(np.array(unit_amps))

    rec = se.NumpyRecordingExtractor(timeseries=timeseries,
                                     sampling_frequency=fs)
    sort = se.NumpySortingExtractor()
    sort.set_times_labels(times=times, labels=labels)
    sort.set_sampling_frequency(fs)

    return rec, sort, waveforms, templates, max_chans, amplitudes
def kilosort2_helper(
        *,
        recording,  # Recording object
        tmpdir,  # Temporary working directory
        detect_sign=-1,  # Polarity of the spikes, -1, 0, or 1
        adjacency_radius=-1,  # Channel neighborhood adjacency radius corresponding to geom file
        detect_threshold=6,  # Threshold for detection
        merge_thresh=.98,  # Cluster merging threshold 0..1
        freq_min=150,  # Lower frequency limit for band-pass filter
        freq_max=6000,  # Upper frequency limit for band-pass filter
        pc_per_chan=3,  # number of PC per chan
        minFR=1 / 50):
    """Run Kilosort2 on a recording through MATLAB and return the sorting.

    The recording is written to ``<tmpdir>/kilosort2_dataset``, the sorter
    arguments are stored in ``argfile.txt`` there, and a MATLAB script
    calling ``p_kilosort2`` is executed from a bash wrapper inside
    ``tmpdir``. The Kilosort2 location is taken from the
    ``KILOSORT2_PATH_DEV`` environment variable when set, otherwise from
    ``KiloSort2.install()``.

    All keyword arguments other than ``recording`` and ``tmpdir`` are
    written verbatim to the argfile as ``key=value`` lines.

    Returns
    -------
    se.NumpySortingExtractor
        Sorting built from rows 1 (times) and 2 (labels) of
        ``<tmpdir>/firings.mda``.

    Raises
    ------
    Exception
        If Kilosort2 cannot be installed, MATLAB exits with a non-zero
        code, or the result file is missing.
    """
    # # TODO: do not require ks2 to depend on irc -- rather, put all necessary .m code in the spikeforest repo
    # ironclust_path = os.environ.get('IRONCLUST_PATH_DEV', None)
    # if ironclust_path:
    #     print('Using ironclust from IRONCLUST_PATH_DEV directory: {}'.format(ironclust_path))
    # else:
    #     try:
    #         print('Auto-installing ironclust.')
    #         ironclust_path = install_ironclust(commit='042b600b014de13f6d11d3b4e50e849caafb4709')
    #     except:
    #         traceback.print_exc()
    #         raise Exception('Problem installing ironclust. You can set the IRONCLUST_PATH_DEV to force to use a particular path.')
    # print('For kilosort2, using ironclust utility functions from: {}'.format(ironclust_path))

    kilosort2_path = os.environ.get('KILOSORT2_PATH_DEV', None)
    if kilosort2_path:
        print('Using kilosort2 from KILOSORT2_PATH_DEV directory: {}'.format(
            kilosort2_path))
    else:
        try:
            print('Auto-installing kilosort2.')
            kilosort2_path = KiloSort2.install()
        except Exception as err:
            # Only wrap real errors; KeyboardInterrupt/SystemExit propagate
            # untouched, and the original cause stays attached.
            traceback.print_exc()
            raise Exception(
                'Problem installing kilosort2. You can set the KILOSORT2_PATH_DEV to force to use a particular path.'
            ) from err
    print('Using kilosort2 from: {}'.format(kilosort2_path))

    source_dir = os.path.dirname(os.path.realpath(__file__))

    dataset_dir = tmpdir + '/kilosort2_dataset'
    # Generate three files in the dataset directory: raw.mda, geom.csv, params.json
    SFMdaRecordingExtractor.write_recording(recording=recording,
                                            save_path=dataset_dir,
                                            _preserve_dtype=True)

    samplerate = recording.get_sampling_frequency()

    print('Reading timeseries header...')
    HH = mdaio.readmda_header(dataset_dir + '/raw.mda')
    num_channels = HH.dims[0]
    num_timepoints = HH.dims[1]
    duration_minutes = num_timepoints / samplerate / 60
    print('Num. channels = {}, Num. timepoints = {}, duration = {} minutes'.
          format(num_channels, num_timepoints, duration_minutes))

    print('Creating argfile.txt file...')
    txt = ''
    txt += 'samplerate={}\n'.format(samplerate)
    txt += 'detect_sign={}\n'.format(detect_sign)
    txt += 'adjacency_radius={}\n'.format(adjacency_radius)
    txt += 'detect_threshold={}\n'.format(detect_threshold)
    txt += 'merge_thresh={}\n'.format(merge_thresh)
    txt += 'freq_min={}\n'.format(freq_min)
    txt += 'freq_max={}\n'.format(freq_max)
    txt += 'pc_per_chan={}\n'.format(pc_per_chan)
    txt += 'minFR={}\n'.format(minFR)
    _write_text_file(dataset_dir + '/argfile.txt', txt)

    print('Running Kilosort2 in {tmpdir}...'.format(tmpdir=tmpdir))
    cmd = '''
        addpath('{source_dir}');
        addpath('{source_dir}/mdaio')
        try
            p_kilosort2('{ksort}', '{tmpdir}', '{raw}', '{geom}', '{firings}', '{arg}');
        catch
            quit(1);
        end
        quit(0);
    '''
    cmd = cmd.format(source_dir=source_dir,
                     ksort=kilosort2_path,
                     tmpdir=tmpdir,
                     raw=dataset_dir + '/raw.mda',
                     geom=dataset_dir + '/geom.csv',
                     firings=tmpdir + '/firings.mda',
                     arg=dataset_dir + '/argfile.txt')
    matlab_cmd = mlpr.ShellScript(cmd,
                                  script_path=tmpdir + '/run_kilosort2.m',
                                  keep_temp_files=True)
    matlab_cmd.write()

    shell_cmd = '''
        #!/bin/bash
        cd {tmpdir}
        echo '=====================' `date` '====================='
        matlab -nosplash -nodisplay -r run_kilosort2
    '''.format(tmpdir=tmpdir)
    shell_cmd = mlpr.ShellScript(shell_cmd,
                                 script_path=tmpdir + '/run_kilosort2.sh',
                                 keep_temp_files=True)
    shell_cmd.write(tmpdir + '/run_kilosort2.sh')
    shell_cmd.start()
    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('kilosort2 returned a non-zero exit code')

    # parse output
    result_fname = tmpdir + '/firings.mda'
    if not os.path.exists(result_fname):
        raise Exception('Result file does not exist: ' + result_fname)

    firings = mdaio.readmda(result_fname)
    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(firings[1, :], firings[2, :])
    return sorting
def toy_example(duration: float = 10., num_channels: int = 4, sampling_frequency: float = 30000., K: int = 10,
                dumpable: bool = False, dump_folder: Optional[Union[str, Path]] = None, seed: Optional[int] = None):
    """
    Create a synthetic recording extractor and its matching sorting extractor.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, both objects are written to disk and reloaded from there
    dump_folder: str or Path
        Folder used when dumpable is True (if None, 'toy_example' is used)
    seed: int
        Seed forwarded to the waveform, firing and timeseries synthesizers

    Returns
    -------
    recording: RecordingExtractor
        A NumpyRecordingExtractor if dumpable is False, otherwise an
        MdaRecordingExtractor
    sorting: SortingExtractor
        A NumpySortingExtractor if dumpable is False, otherwise an
        NpzSortingExtractor
    """
    upsamplefac = 13

    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac, seed=seed)
    times, labels = synthesize_random_firings(K=K, duration=duration, sampling_frequency=sampling_frequency,
                                              seed=seed)

    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(times, labels.astype(np.int64))

    timeseries = synthesize_timeseries(sorting=sorting, waveforms=waveforms, noise_level=10,
                                       sampling_frequency=sampling_frequency, duration=duration,
                                       waveform_upsamplefac=upsamplefac, seed=seed)
    sorting.set_sampling_frequency(sampling_frequency)

    recording = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=sampling_frequency,
                                           geom=geom)
    recording.is_filtered = True

    if not dumpable:
        return recording, sorting

    # Dumpable variant: write both objects to disk and reload them from
    # their files.
    folder = Path('toy_example' if dump_folder is None else dump_folder)
    se.MdaRecordingExtractor.write_recording(recording, folder)
    recording = se.MdaRecordingExtractor(folder)
    se.NpzSortingExtractor.write_sorting(sorting, folder / 'sorting.npz')
    sorting = se.NpzSortingExtractor(folder / 'sorting.npz')

    return recording, sorting
def test_empty_write(self):
    """Write a sorting extractor that has no units to an .npz file."""
    empty_sorting = se.NumpySortingExtractor()
    target = 'test_NpzSortingExtractors_empty.npz'
    se.NpzSortingExtractor.write_sorting(empty_sorting, target)