def _create_example(self, seed):
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    sampling_frequency = 30000
    X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames))
    geom = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, 2))
    X = (X * 100).astype(int)
    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    SX.add_unit(unit_id=2, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[1])))
    SX.add_unit(unit_id=3, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[2])))
    SX.set_unit_property(unit_id=1, property_name='stability', value=80)
    SX.set_sampling_frequency(sampling_frequency)
    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    train2 = np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    SX2.add_unit(unit_id=4, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[1]))
    SX2.add_unit(unit_id=5, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[2]))
    SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
    SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
    RX.set_channel_locations([0, 0], channel_ids=0)
    for i, unit_id in enumerate(SX2.get_unit_ids()):
        SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)
        SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
                                    value=np.asarray([i] * spike_times2[i]))
    for i, channel_id in enumerate(RX.get_channel_ids()):
        RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)
    SX3 = se.NumpySortingExtractor()
    train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47])
    SX3.add_unit(unit_id=0, times=train3)
    features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35])
    features4 = np.asarray([0, 10, 20, 30])
    feature4_idx = np.asarray([0, 2, 4, 6])
    SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)
    SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4, indexes=feature4_idx)
    example_info = dict(
        channel_ids=channel_ids,
        num_channels=num_channels,
        num_frames=num_frames,
        sampling_frequency=sampling_frequency,
        unit_ids=[1, 2, 3],
        train1=train1,
        train2=train2,
        train3=train3,
        features3=features3,
        unit_prop=80,
        channel_prop=(0, 0)
    )
    return (RX, RX2, RX3, SX, SX2, SX3, example_info)
def compute_units_info(*, recording, sorting, channel_ids=[], unit_ids=[]):
    if (channel_ids) and (len(channel_ids) > 0):
        recording = si.SubRecordingExtractor(parent_recording=recording, channel_ids=channel_ids)

    # load the recording into memory
    print('Loading recording into RAM...')
    recording = si.NumpyRecordingExtractor(timeseries=recording.get_traces(),
                                           samplerate=recording.get_sampling_frequency())

    # bandpass filter
    print('Filtering...')
    recording = bandpass_filter(recording=recording, freq_min=300, freq_max=6000)
    recording = si.NumpyRecordingExtractor(timeseries=recording.get_traces(),
                                           samplerate=recording.get_sampling_frequency())

    if (not unit_ids) or (len(unit_ids) == 0):
        unit_ids = sorting.get_unit_ids()

    print('Computing channel noise levels...')
    channel_noise_levels = compute_channel_noise_levels(recording=recording)

    # No longer use a subset of events to compute the templates
    print('Computing unit templates...')
    templates = compute_unit_templates(recording=recording, sorting=sorting, unit_ids=unit_ids, max_num=100)

    ret = []
    for i, unit_id in enumerate(unit_ids):
        print('Unit {} of {} (id={})'.format(i + 1, len(unit_ids), unit_id))
        template = templates[i]
        max_p2p_amps_on_channels = np.max(template, axis=1) - np.min(template, axis=1)
        peak_channel_index = np.argmax(max_p2p_amps_on_channels)
        peak_channel = recording.get_channel_ids()[peak_channel_index]
        peak_signal = np.max(np.abs(template[peak_channel_index, :]))
        info0 = dict()
        info0['unit_id'] = int(unit_id)
        info0['snr'] = peak_signal / channel_noise_levels[peak_channel_index]
        info0['peak_channel'] = int(peak_channel)  # peak_channel is already a channel id
        train = sorting.get_unit_spike_train(unit_id=unit_id)
        info0['num_events'] = int(len(train))
        info0['firing_rate'] = float(len(train) / (recording.get_num_frames() / recording.get_sampling_frequency()))
        ret.append(info0)
    return ret
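# Hedged usage sketch for compute_units_info (names here are illustrative, not
# part of the original API): given any recording/sorting pair, e.g. from one of
# the toy_example helpers defined later in this section, print the summary
# fields the function returns for each unit.
def _example_compute_units_info(recording, sorting):
    info = compute_units_info(recording=recording, sorting=sorting)
    for unit in info:
        print(unit['unit_id'],
              'snr={:.2f}'.format(unit['snr']),
              'peak_channel={}'.format(unit['peak_channel']),
              'firing_rate={:.2f} Hz'.format(unit['firing_rate']))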
def _create_example(self):
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    samplerate = 30000
    X = np.random.normal(0, 1, (num_channels, num_frames))
    geom = np.random.normal(0, 1, (num_channels, 2))
    X = (X * 100).astype(int)
    RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    RX2 = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    RX3 = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(np.rint(np.random.uniform(0, num_frames, spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    SX.add_unit(unit_id=2, times=np.sort(np.random.uniform(0, num_frames, spike_times[1])))
    SX.add_unit(unit_id=3, times=np.sort(np.random.uniform(0, num_frames, spike_times[2])))
    SX.set_unit_property(unit_id=1, property_name='stability', value=80)
    SX.set_sampling_frequency(samplerate)
    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
    SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
    SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
    RX.set_channel_property(channel_id=0, property_name='location', value=(0, 0))
    example_info = dict(channel_ids=channel_ids, num_channels=num_channels, num_frames=num_frames,
                        samplerate=samplerate, unit_ids=[1, 2, 3], train1=train1, unit_prop=80,
                        channel_prop=(0, 0))
    return (RX, RX2, RX3, SX, SX2, example_info)
def test_remove_bad_channels():
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
    assert 0 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
    assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)
    shutil.rmtree('test')

    timeseries = np.random.randn(4, 60000)
    timeseries[1] = 10 * timeseries[1]
    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    se.MdaRecordingExtractor.write_recording(rec_np, 'test')
    rec = se.MdaRecordingExtractor('test')

    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, seconds=0.1)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, seconds=10)
    assert 1 not in rec_rm.get_channel_ids()
    check_dumping(rec_rm)

    shutil.rmtree('test')
def setUp(self):
    M = 4
    N = 10000
    N_ttl = 50
    seed = 0
    sampling_frequency = 30000
    X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
    geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2))
    self._X = X
    self._geom = geom
    self._sampling_frequency = sampling_frequency
    self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    self._ttl_frames = np.sort(np.random.permutation(N)[:N_ttl])
    self.RX.set_ttls(self._ttl_frames)
    self.SX = se.NumpySortingExtractor()
    L = 200
    self._train1 = np.rint(np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int)
    self.SX.add_unit(unit_id=1, times=self._train1)
    self.SX.add_unit(unit_id=2, times=np.random.RandomState(seed=seed).uniform(0, N, L))
    self.SX.add_unit(unit_id=3, times=np.random.RandomState(seed=seed).uniform(0, N, L))
def gen_synth_datasets(datasets, *, outdir, samplerate=32000):
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        spiketrains = gen_spiketrains(duration=ds['duration'], n_exc=ds['n_exc'], n_inh=ds['n_inh'],
                                      f_exc=ds['f_exc'], f_inh=ds['f_inh'], min_rate=ds['min_rate'],
                                      st_exc=ds['st_exc'], st_inh=ds['st_inh'])
        OX = NeoSpikeTrainsOutputExtractor(spiketrains=spiketrains, samplerate=samplerate)
        X, geom = gen_recording(templates=ds['templates'], output_extractor=OX,
                                noise_level=ds['noise_level'], samplerate=samplerate,
                                duration=ds['duration'])
        IX = si.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
        si.MdaRecordingExtractor.writeRecording(IX, outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(OX, outdir + '/{}/firings_true.mda'.format(ds_name))
    print('Done.')
def toy_example(duration=10, num_channels=4, samplerate=30000.0, K=10, seed=None):
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac, seed=seed)
    times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10, samplerate=samplerate,
                              duration=duration, waveform_upsamplefac=upsamplefac)
    SX.set_sampling_frequency(samplerate)
    RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    return (RX, SX)
def test_remove_bad_channels():
    rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4)

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
    assert 0 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
    assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()

    timeseries = np.random.randn(4, 60000)
    timeseries[1] = 3 * timeseries[1]
    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)

    rec_rm = remove_bad_channels(rec_np, bad_channel_ids=None, bad_threshold=2)
    assert 1 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec_np, bad_channel_ids=None, bad_threshold=2, seconds=0.1)
    assert 1 not in rec_rm.get_channel_ids()

    rec_rm = remove_bad_channels(rec_np, bad_channel_ids=None, bad_threshold=2, seconds=10)
    assert 1 not in rec_rm.get_channel_ids()
def create_simulated_recording(size, num_frames=1000, sampling_frequency=30000, seed=0):
    # TODO: if the grid is centered at (0, 0), an even size puts two channels at position 0
    # channel_pos = [int(coord - (size - 1) / 2) for coord in range(0, size)]
    channel_pos = [coord for coord in range(0, size)]
    geom = []
    for k in channel_pos:
        for j in channel_pos:
            geom.append([j, k, 0])
    geom = np.asarray(geom)
    channel_ids = np.arange(0, size * size)
    num_channels = len(channel_ids)
    X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames))
    X = (X * 100).astype(int)
    X, spike_frame_channel_array = add_artificial_spikes(X)
    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    return geom, RX, spike_frame_channel_array
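# Hedged usage sketch for create_simulated_recording, assuming
# add_artificial_spikes is importable alongside it: build a 3x3 simulated grid
# and check that the extractor's channel count matches the generated geometry.
def _example_create_simulated_recording():
    geom, rx, spike_frame_channel_array = create_simulated_recording(size=3)
    assert len(rx.get_channel_ids()) == geom.shape[0] == 9
    print(rx.get_num_frames(), rx.get_sampling_frequency())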
def __init__(self, file_path, acquisition_name=None):
    assert HAVE_NWB, "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n"
    self._path = file_path
    self._acquisition_name = acquisition_name
    with NWBHDF5IO(file_path, 'r') as io:
        nwbfile = io.read()
        if acquisition_name is None:
            a_names = list(nwbfile.acquisition.keys())
            if len(a_names) > 1:
                raise Exception('More than one acquisition found. You must specify acquisition_name.')
            if len(a_names) == 0:
                raise Exception('No acquisitions found in the .nwb file.')
            acquisition_name = a_names[0]
        ts = nwbfile.acquisition[acquisition_name]
        self._nwb_timeseries = ts
        M = np.array(ts.data).shape[1]
        if M != len(ts.electrodes):
            raise Exception('Number of electrodes does not match the shape of the data {}<>{}'.format(
                M, len(ts.electrodes)))
        geom = np.zeros((M, 3))
        for m in range(M):
            geom[m, :] = [ts.electrodes[m][1], ts.electrodes[m][2], ts.electrodes[m][3]]
        if hasattr(ts, 'timestamps') and ts.timestamps:
            # estimate the rate from the first two timestamps; there is probably a better way
            sampling_frequency = 1 / (ts.timestamps[1] - ts.timestamps[0])
        else:
            sampling_frequency = ts.rate  # NWB stores the sampling rate in Hz
        data = np.copy(np.transpose(ts.data))
        NRX = se.NumpyRecordingExtractor(timeseries=data, sampling_frequency=sampling_frequency, geom=geom)
        CopyRecordingExtractor.__init__(self, NRX)
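# Hedged usage sketch for the NWB-backed extractor above, assuming this
# __init__ belongs to a class named NwbRecordingExtractor. The path
# 'session.nwb' and acquisition name 'ElectricalSeries' are hypothetical;
# substitute the names present in your own .nwb file.
def _example_nwb_recording():
    rx = NwbRecordingExtractor('session.nwb', acquisition_name='ElectricalSeries')
    traces = rx.get_traces(start_frame=0, end_frame=1000)
    print(traces.shape, rx.get_sampling_frequency())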
def setUp(self):
    M = 32
    N = 10000
    samplerate = 30000
    X = np.random.normal(0, 1, (M, N))
    self._X = X
    self._samplerate = samplerate
    self.RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate)
    self.test_dir = tempfile.mkdtemp()
def createSession(self):
    recording = SFMdaRecordingExtractor(dataset_directory=self._recording_directory, download=False)
    recording = se.SubRecordingExtractor(parent_recording=recording, start_frame=0, end_frame=10000)
    recording = se.NumpyRecordingExtractor(timeseries=recording.get_traces(),
                                           samplerate=recording.get_sampling_frequency())
    W = SFW.TimeseriesWidget(recording=recording)
    _make_full_browser(W)
    return W
def setUp(self):
    M = 32
    N = 10000
    seed = 0
    sampling_frequency = 30000
    X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))
    self._X = X
    self._sampling_frequency = sampling_frequency
    self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency)
    self.RX.set_channel_locations(np.random.randn(M, 3))
    self.test_dir = tempfile.mkdtemp()
def toy_example1(duration=10, num_channels=4, samplerate=30000, K=10, firing_rates=None, noise_level=10):
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac)
    times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate,
                                              firing_rates=firing_rates)
    labels = labels.astype(np.int64)
    OX = se.NumpySortingExtractor()
    OX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=OX, waveforms=waveforms, noise_level=noise_level,
                              samplerate=samplerate, duration=duration,
                              waveform_upsamplefac=upsamplefac)
    IX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    return (IX, OX)
def test_highpass_filter():
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)

    rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')
    assert check_signal_power_signal1_below_signal2(rec_fft.get_traces(), rec.get_traces(),
                                                    freq_range=[1000, 5000],
                                                    fs=rec.get_sampling_frequency())

    rec_sci = bandpass_filter(rec, freq_min=3000, freq_max=6000, filter_type='butter', order=3)
    assert check_signal_power_signal1_below_signal2(rec_sci.get_traces(), rec.get_traces(),
                                                    freq_range=[1000, 3000],
                                                    fs=rec.get_sampling_frequency())

    traces = rec.get_traces().astype('uint16')
    rec_u = se.NumpyRecordingExtractor(traces, sampling_frequency=rec.get_sampling_frequency())
    rec_fu = bandpass_filter(rec_u, freq_min=5000, freq_max=10000, filter_type='fft')
    assert check_signal_power_signal1_below_signal2(rec_fu.get_traces(), rec_u.get_traces(),
                                                    freq_range=[1000, 5000],
                                                    fs=rec.get_sampling_frequency())
    assert check_signal_power_signal1_below_signal2(rec_fu.get_traces(), rec_u.get_traces(),
                                                    freq_range=[10000, 15000],
                                                    fs=rec.get_sampling_frequency())
    assert not str(rec_fu.get_dtype()).startswith('u')

    check_dumping(rec_fft)
    shutil.rmtree('test')
def setUp(self):
    M = 4
    N = 10000
    samplerate = 30000
    X = np.random.normal(0, 1, (M, N))
    geom = np.random.normal(0, 1, (M, 2))
    self._X = X
    self._geom = geom
    self._samplerate = samplerate
    self.RX = se.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
    self.SX = se.NumpySortingExtractor()
    L = 200
    self._train1 = np.rint(np.random.uniform(0, N, L)).astype(int)
    self.SX.add_unit(unit_id=1, times=self._train1)
    self.SX.add_unit(unit_id=2, times=np.random.uniform(0, N, L))
    self.SX.add_unit(unit_id=3, times=np.random.uniform(0, N, L))
def toy_example(duration=10, num_channels=4, sampling_frequency=30000.0, K=10, dumpable=False,
                dump_folder=None, seed=None):
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac, seed=seed)
    times, labels = synthesize_random_firings(K=K, duration=duration,
                                              sampling_frequency=sampling_frequency, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10,
                              sampling_frequency=sampling_frequency, duration=duration,
                              waveform_upsamplefac=upsamplefac, seed=seed)
    SX.set_sampling_frequency(sampling_frequency)
    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)

        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
def gen_synth_datasets(datasets, *, outdir, num_channels=4, upsamplefac=13, samplerate=30000,
                       average_peak_amplitude=-100):
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    for ds in datasets:
        ds_name = ds['name']
        print(ds_name)
        K = ds['K']
        duration = ds['duration']
        noise_level = ds['noise_level']
        waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels,
                                                      average_peak_amplitude=average_peak_amplitude,
                                                      upsamplefac=upsamplefac)
        times, labels = synthesize_random_firings(K=K, duration=duration, samplerate=samplerate)
        labels = labels.astype(np.int64)
        OX = si.NumpySortingExtractor()
        OX.setTimesLabels(times, labels)
        X = synthesize_timeseries(output_extractor=OX, waveforms=waveforms, noise_level=noise_level,
                                  samplerate=samplerate, duration=duration,
                                  waveform_upsamplefac=upsamplefac)
        IX = si.NumpyRecordingExtractor(timeseries=X, samplerate=samplerate, geom=geom)
        si.MdaRecordingExtractor.writeRecording(IX, outdir + '/{}'.format(ds_name))
        si.MdaSortingExtractor.writeSorting(OX, outdir + '/{}/firings_true.mda'.format(ds_name))
        # TODO: write a json with two fields
    print('Done.')
def _create_example(self):
    channel_ids = [0, 1, 2, 3]
    num_channels = 4
    num_frames = 10000
    sampling_frequency = 30000
    X = np.random.normal(0, 1, (num_channels, num_frames))
    geom = np.random.normal(0, 1, (num_channels, 2))
    X = (X * 100).astype(int)
    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    SX = se.NumpySortingExtractor()
    spike_times = [200, 300, 400]
    train1 = np.sort(np.rint(np.random.uniform(0, num_frames, spike_times[0])).astype(int))
    SX.add_unit(unit_id=1, times=train1)
    SX.add_unit(unit_id=2, times=np.sort(np.random.uniform(0, num_frames, spike_times[1])))
    SX.add_unit(unit_id=3, times=np.sort(np.random.uniform(0, num_frames, spike_times[2])))
    SX.set_unit_property(unit_id=1, property_name='stability', value=80)
    SX.set_sampling_frequency(sampling_frequency)
    SX2 = se.NumpySortingExtractor()
    spike_times2 = [100, 150, 450]
    train2 = np.rint(np.random.uniform(0, num_frames, spike_times2[0])).astype(int)
    SX2.add_unit(unit_id=3, times=train2)
    SX2.add_unit(unit_id=4, times=np.random.uniform(0, num_frames, spike_times2[1]))
    SX2.add_unit(unit_id=5, times=np.random.uniform(0, num_frames, spike_times2[2]))
    SX2.set_unit_property(unit_id=4, property_name='stability', value=80)
    SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))
    RX.set_channel_property(channel_id=0, property_name='location', value=(0, 0))
    for i, unit_id in enumerate(SX2.get_unit_ids()):
        SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)
        SX2.set_unit_spike_features(unit_id=unit_id, feature_name='shared_unit_feature',
                                    value=np.asarray([i] * spike_times2[i]))
    for i, channel_id in enumerate(RX.get_channel_ids()):
        RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)
    example_info = dict(channel_ids=channel_ids, num_channels=num_channels, num_frames=num_frames,
                        sampling_frequency=sampling_frequency, unit_ids=[1, 2, 3], train1=train1,
                        unit_prop=80, channel_prop=(0, 0))
    return (RX, RX2, RX3, SX, SX2, example_info)
def sort_main(task, overwrite_flag=0):
    try:
        save_path = Path(task['save_path'], task['task_type'])
        if (not (save_path / 'recording.dat').exists()) or overwrite_flag:
            # load task data
            data = np.load(task['file_path'])
            with open(task['file_header_path'], 'rb') as f:
                data_info = pickle.load(f)

            # prepare filter
            sos, _ = pp.get_sos_filter_bank(['Sp'], fs=data_info['fs'])
            spk_data = np.zeros_like(data)
            assert data_info['n_chans'] == spk_data.shape[0], "Inconsistent formatting in the data files. Aborting."

            # spk filter (high pass)
            t0 = time.time()
            for ch in range(data_info['n_chans']):
                spk_data[ch] = scipy.signal.sosfiltfilt(sos, data[ch])
                print('', end='.')
            t1 = time.time()
            print('\nTime to spk filter data {0:0.2f}s'.format(t1 - t0))

            chan_masks = pp.create_chan_masks(data_info['Raw']['ClippedSegs'], data_info['n_samps'])
            chan_mad = pp.get_signals_mad(spk_data, chan_masks)
            data_info['Spk'] = {'mad': chan_mad}

            # convert data to spikeinterface format
            spk_data_masked = se.NumpyRecordingExtractor(timeseries=spk_data * chan_masks,
                                                         geom=data_info['tt_geom'],
                                                         sampling_frequency=data_info['fs'])

            # sort data
            sort = sort_data(spk_data_masked, save_path, sorter=task['task_type'])
            if sort is not None:
                # export data to phy
                st.postprocessing.export_to_phy(recording=spk_data_masked, sorting=sort,
                                                output_folder=str(save_path),
                                                compute_pc_features=False,
                                                compute_amplitudes=False,
                                                max_channels_per_template=4)
                # get cluster stats
                spk_times_list = sort.get_units_spike_train()
                cluster_stats = get_cluster_stats(spk_times_list, spk_data_masked.get_traces(), data_info)
                cluster_stats_file_path = Path(save_path, 'cluster_stats.csv')
                cluster_stats.to_csv(cluster_stats_file_path)
                print('Successful sort.')
            else:
                print('Unsuccessful sort.')

            # save header
            updated_file_header_path = Path(task['save_path'], Path(task['file_header_path']).name)
            with updated_file_header_path.open(mode='wb') as file_handle:
                pickle.dump(data_info, file_handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            print('Sorting done and overwrite flag is False; skipping this sort.')
    except KeyboardInterrupt:
        print('Keyboard interrupt detected. Aborting task processing.')
        sys.exit()
    except Exception:
        print("Error", sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2].tb_lineno)
        traceback.print_exc(file=sys.stdout)
def run(self):
    # This temporary directory will automatically be removed even in the case of a python exception
    with TemporaryDirectory() as tmpdir:
        # names of files for the temporary/intermediate data
        filt = tmpdir + '/filt.mda'
        filt2 = tmpdir + '/filt2.mda'
        pre = tmpdir + '/pre.mda'

        print('Bandpass filtering raw -> filt...')
        _bandpass_filter(self.recording_file_in, filt)

        if self.mask_out_artifacts:
            print('Masking out artifacts filt -> filt2...')
            _mask_out_artifacts(filt, filt2)
        else:
            print('Copying filt -> filt2...')
            filt2 = filt

        if self.whiten:
            print('Whitening filt2 -> pre...')
            _whiten(filt2, pre)
        else:
            pre = filt2

        # read the preprocessed timeseries into RAM (maybe we'll do it differently later)
        X = sf.mdaio.readmda(pre)

        # handle the geom
        if type(self.geom_in) == str:
            print('Using geom.csv from a file', self.geom_in)
            geom = read_geom_csv(self.geom_in)
        else:
            # no geom file was provided as input
            num_channels = X.shape[0]
            if num_channels > 6:
                raise Exception('For more than six channels, we require that a geom.csv be provided')
            # otherwise make a trivial geometry file
            print('Making a trivial geom file.')
            geom = np.zeros((X.shape[0], 2))

        # Now represent the preprocessed recording using a RecordingExtractor
        recording = se.NumpyRecordingExtractor(X, samplerate=30000, geom=geom)  # hard-code the samplerate for now

        # idea: run many simultaneous jobs, each using only 2 cores; it is important to set
        # certain environment variables in the .sh script that calls this .py script
        num_workers = 2

        # Call MountainSort4
        sorting = ml_ms4alg.mountainsort4(
            recording=recording,
            detect_sign=self.detect_sign,
            adjacency_radius=self.adjacency_radius,
            clip_size=self.clip_size,
            detect_threshold=self.detect_threshold,
            detect_interval=self.detect_interval,
            num_workers=num_workers,
        )

        # Write the firings.mda
        print('Writing firings.mda...')
        sf.SFMdaSortingExtractor.write_sorting(sorting=sorting, save_path=self.firings_out)

        print('Computing cluster metrics...')
        cluster_metrics_path = tmpdir + '/cluster_metrics.json'
        _cluster_metrics(pre, self.firings_out, cluster_metrics_path)

        print('Computing isolation metrics...')
        isolation_metrics_path = tmpdir + '/isolation_metrics.json'
        pair_metrics_path = tmpdir + '/pair_metrics.json'
        _isolation_metrics(pre, self.firings_out, isolation_metrics_path, pair_metrics_path)

        print('Combining metrics...')
        metrics_path = tmpdir + '/metrics.json'
        _combine_metrics(cluster_metrics_path, isolation_metrics_path, metrics_path)

        # copy metrics.json to the output location
        shutil.copy(metrics_path, self.metrics_out)

        print('Creating label map...')
        label_map_path = tmpdir + '/label_map.mda'
        create_label_map(metrics=metrics_path, label_map_out=label_map_path)

        print('Applying label map...')
        apply_label_map(firings=self.firings_out, label_map=label_map_path,
                        firings_out=self.firings_curated_out)
def create_signal_with_known_waveforms(n_channels=4, n_waveforms=2, n_wf_samples=100, duration=5, fs=30000):
    '''
    Creates a stereotyped recording and sorting, with waveforms, templates, and max_chans
    '''
    a_min = [-200, -50]
    a_max = [10, 50]
    wfs = []

    # generate waveforms
    for w in range(n_waveforms):
        amp_min = np.random.randint(a_min[0], a_min[1])
        amp_max = np.random.randint(a_max[0], a_max[1])
        wf = create_wf(amp_min, amp_max, n_wf_samples)
        wfs.append(wf)

    # generate templates (keep only the template that satisfies the amplitude constraint)
    templates = []
    max_chans = []
    for wf in wfs:
        found = False
        while not found:
            template, amps, found = generate_template_with_random_amps(n_channels, wf)
        templates.append(template)
        max_chans.append(np.argmax(amps))

    templates = np.array(templates)
    n_samples = int(fs * duration)

    # generate spike trains
    interval = 10 * n_wf_samples
    times = np.arange(interval, duration * fs - interval, interval).astype(int)
    labels = np.zeros(len(times)).astype(int)
    for i, wf in enumerate(wfs):
        labels[i::len(wfs)] = i

    timeseries = np.zeros((n_channels, n_samples))
    waveforms = []
    amplitudes = []
    for i, tem in enumerate(templates):
        idxs = np.where(labels == i)
        wav = []
        amps = []
        for t in times[idxs]:
            rand_val = np.random.randn() * 0.01 + 1
            timeseries[:, t - n_wf_samples // 2:t + n_wf_samples // 2] = rand_val * tem
            wav.append(rand_val * tem)
            amps.append(np.min(rand_val * tem))
        wav = np.array(wav)
        amps = np.array(amps)
        waveforms.append(wav)
        amplitudes.append(amps)

    rec = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=fs)
    sort = se.NumpySortingExtractor()
    sort.set_times_labels(times=times, labels=labels)
    sort.set_sampling_frequency(fs)
    return rec, sort, waveforms, templates, max_chans, amplitudes
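# Hedged usage sketch for create_signal_with_known_waveforms (it relies on the
# create_wf and generate_template_with_random_amps helpers defined alongside
# it): the returned templates and max_chans provide ground truth against which
# waveform-extraction code can be validated.
def _example_known_waveforms():
    rec, sort, waveforms, templates, max_chans, amplitudes = \
        create_signal_with_known_waveforms(n_channels=4, n_waveforms=2)
    assert templates.shape[0] == len(max_chans) == 2
    print(sort.get_unit_ids(), templates.shape)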
def toy_example(duration: float = 10., num_channels: int = 4, sampling_frequency: float = 30000.,
                K: int = 10, dumpable: bool = False, dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable'
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor,
        otherwise it's an MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a NumpySortingExtractor,
        otherwise it's an NpzSortingExtractor
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac, seed=seed)
    times, labels = synthesize_random_firings(K=K, duration=duration,
                                              sampling_frequency=sampling_frequency, seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10,
                              sampling_frequency=sampling_frequency, duration=duration,
                              waveform_upsamplefac=upsamplefac, seed=seed)
    SX.set_sampling_frequency(sampling_frequency)
    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)

        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
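# Hedged usage sketch for toy_example: with dumpable=True the toy data are
# round-tripped through MdaRecordingExtractor/NpzSortingExtractor, so the
# returned objects can be serialized (useful e.g. for sorters that run in
# separate processes). 'toy_dump' is an illustrative folder name.
def _example_toy_dumpable():
    rx, sx = toy_example(duration=2, num_channels=4, dumpable=True,
                         dump_folder='toy_dump', seed=0)
    assert isinstance(rx, se.MdaRecordingExtractor)
    assert isinstance(sx, se.NpzSortingExtractor)
    print(rx.get_num_frames(), len(sx.get_unit_ids()))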
def test_accuracy_of_denoising():
    # Test the accuracy of denoising
    duration = 10
    num_channels = 4
    sampling_frequency = 30000
    K = 10
    seed = None
    upsamplefac = 13
    waveforms, geom = example_datasets.synthesize_random_waveforms(K=K, M=num_channels,
                                                                   average_peak_amplitude=-100,
                                                                   upsamplefac=upsamplefac, seed=seed)
    times, labels = example_datasets.synthesize_random_firings(K=K, duration=duration,
                                                               sampling_frequency=sampling_frequency,
                                                               seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    SX.set_sampling_frequency(sampling_frequency)

    recordings = []
    for noise_level in [0, 10]:
        X = example_datasets.synthesize_timeseries(sorting=SX, waveforms=waveforms,
                                                   noise_level=noise_level,
                                                   sampling_frequency=sampling_frequency,
                                                   duration=duration,
                                                   waveform_upsamplefac=upsamplefac, seed=seed)
        RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)
        recordings.append(RX)
    recording_without_noise = recordings[0]
    recording_with_noise = recordings[1]

    opts = ephys_nlm_v1_opts(
        multi_neighborhood=False,
        block_size_sec=30,
        clip_size=30,
        sigma='auto',
        sigma_scale_factor=1,
        whitening='auto',
        whitening_pctvar=90,
        denom_threshold=30
    )

    recording_denoised, runtime_info = ephys_nlm_v1(
        recording=recording_with_noise,
        opts=opts,
        device=None,  # detect from the EPHYS_NLM_DEVICE environment variable
        verbose=2
    )

    traces_with_noise = recording_with_noise.get_traces()
    traces_without_noise = recording_without_noise.get_traces()
    traces_denoised = recording_denoised.get_traces()

    std_noise_before = np.sqrt(np.var(traces_without_noise - traces_with_noise))
    std_noise_after = np.sqrt(np.var(traces_without_noise - traces_denoised))
    print(f'std_noise_before = {std_noise_before}; std_noise_after = {std_noise_after};')

    assert std_noise_after < 0.3 * std_noise_before
def run_conversion(self, nwbfile: NWBFile, metadata: dict, stub_test: bool = False,
                   write_ecephys_metadata: bool = False):
    if 'UnitProperties' not in metadata:
        metadata['UnitProperties'] = []
    if write_ecephys_metadata and 'Ecephys' in metadata:
        n_channels = max([len(x['data']) for x in metadata['Ecephys']['Electrodes']])
        recording = se.NumpyRecordingExtractor(timeseries=np.array(range(n_channels)),
                                               sampling_frequency=1)
        se.NwbRecordingExtractor.add_devices(recording=recording, nwbfile=nwbfile, metadata=metadata)
        se.NwbRecordingExtractor.add_electrode_groups(recording=recording, nwbfile=nwbfile,
                                                      metadata=metadata)
        se.NwbRecordingExtractor.add_electrodes(recording=recording, nwbfile=nwbfile, metadata=metadata)

    property_descriptions = dict()
    if stub_test:
        max_min_spike_time = max([min(x) for y in self.sorting_extractor.get_unit_ids()
                                  for x in [self.sorting_extractor.get_unit_spike_train(y)] if any(x)])
        stub_sorting_extractor = se.SubSortingExtractor(
            self.sorting_extractor,
            unit_ids=self.sorting_extractor.get_unit_ids(),
            start_frame=0,
            end_frame=1.1 * max_min_spike_time)
        sorting_extractor = stub_sorting_extractor
    else:
        sorting_extractor = self.sorting_extractor

    for metadata_column in metadata['UnitProperties']:
        assert len(metadata_column['data']) == len(sorting_extractor.get_unit_ids()), \
            f"The metadata_column '{metadata_column['name']}' data must have the same dimension as the sorting IDs!"
        property_descriptions.update({metadata_column['name']: metadata_column['description']})
        for unit_idx, unit_id in enumerate(sorting_extractor.get_unit_ids()):
            if metadata_column['name'] == 'electrode_group':
                if nwbfile.electrode_groups:
                    data = nwbfile.electrode_groups[metadata_column['data'][unit_idx]]
                    sorting_extractor.set_unit_property(unit_id, metadata_column['name'], data)
            else:
                data = metadata_column['data'][unit_idx]
                sorting_extractor.set_unit_property(unit_id, metadata_column['name'], data)

    se.NwbSortingExtractor.write_sorting(sorting_extractor,
                                         property_descriptions=property_descriptions,
                                         nwbfile=nwbfile)
def yuta2nwb(
        session_path='D:/BuzsakiData/SenzaiY/YutaMouse41/YutaMouse41-150903',
        # '/Users/bendichter/Desktop/Buzsaki/SenzaiBuzsaki2017/YutaMouse41/YutaMouse41-150903',
        subject_xls=None, include_spike_waveforms=True, stub=True, cache_spec=True):

    subject_path, session_id = os.path.split(session_path)
    fpath_base = os.path.split(subject_path)[0]
    identifier = session_id
    mouse_number = session_id[9:11]
    if '-' in session_id:
        subject_id, date_text = session_id.split('-')
        b = False
    else:
        subject_id, date_text = session_id.split('b')
        b = True

    if subject_xls is None:
        subject_xls = os.path.join(subject_path, 'YM' + mouse_number + ' exp_sheet.xlsx')
    else:
        if not subject_xls[-4:] == 'xlsx':
            subject_xls = os.path.join(subject_xls, 'YM' + mouse_number + ' exp_sheet.xlsx')

    session_start_time = dateparse(date_text, yearfirst=True)

    df = pd.read_excel(subject_xls)

    subject_data = {}
    for key in ['genotype', 'DOB', 'implantation', 'Probe', 'Surgery', 'virus injection', 'mouseID']:
        names = df.iloc[:, 0]
        if key in names.values:
            subject_data[key] = df.iloc[np.argmax(names == key), 1]

    if isinstance(subject_data['DOB'], datetime):
        age = session_start_time - subject_data['DOB']
    else:
        age = None

    subject = Subject(subject_id=subject_id, age=str(age),
                      genotype=subject_data['genotype'], species='mouse')

    nwbfile = NWBFile(
        session_description='mouse in open exploration and theta maze',
        identifier=identifier,
        session_start_time=session_start_time.astimezone(),
        file_create_date=datetime.now().astimezone(),
        experimenter='Yuta Senzai',
        session_id=session_id,
        institution='NYU',
        lab='Buzsaki',
        subject=subject,
        related_publications='DOI:10.1016/j.neuron.2016.12.011')

    print('reading and writing raw position data...', end='', flush=True)
    ns.add_position_data(nwbfile, session_path)

    shank_channels = ns.get_shank_channels(session_path)[:8]
    nshanks = len(shank_channels)
    all_shank_channels = np.concatenate(shank_channels)

    print('setting up electrodes...', end='', flush=True)
    hilus_csv_path = os.path.join(fpath_base, 'early_session_hilus_chans.csv')
    lfp_channel = get_reference_elec(subject_xls, hilus_csv_path, session_start_time, session_id, b=b)

    custom_column = [{'name': 'theta_reference',
                      'description': 'this electrode was used to calculate LFP canonical bands',
                      'data': all_shank_channels == lfp_channel}]
    # max_shanks, task_types, special_electrode_dict and celltype_dict are module-level constants
    ns.write_electrode_table(nwbfile, session_path, custom_columns=custom_column, max_shanks=max_shanks)

    print('reading raw electrode data...', end='', flush=True)
    if stub:
        # example recording extractor for fast testing
        xml_filepath = os.path.join(session_path, session_id + '.xml')
        xml_root = et.parse(xml_filepath).getroot()
        acq_sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text)
        num_channels = 4
        num_frames = 10000
        X = np.random.normal(0, 1, (num_channels, num_frames))
        geom = np.random.normal(0, 1, (num_channels, 2))
        X = (X * 100).astype(int)
        sre = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=acq_sampling_frequency, geom=geom)
    else:
        nre = se.NeuroscopeRecordingExtractor('{}/{}.dat'.format(session_path, session_id))
        sre = se.SubRecordingExtractor(nre, channel_ids=all_shank_channels)

    print('writing raw electrode data...', end='', flush=True)
    se.NwbRecordingExtractor.add_electrical_series(sre, nwbfile)
    print('done.')

    print('reading spiking units...', end='', flush=True)
    if stub:
        spike_times = [200, 300, 400]
        num_frames = 10000
        allshanks = []
        for k in range(nshanks):
            SX = se.NumpySortingExtractor()
            for j in range(len(spike_times)):
                SX.add_unit(unit_id=j + 1, times=np.sort(np.random.uniform(0, num_frames, spike_times[j])))
            allshanks.append(SX)
        se_allshanks = se.MultiSortingExtractor(allshanks)
        se_allshanks.set_sampling_frequency(acq_sampling_frequency)
    else:
        se_allshanks = se.NeuroscopeMultiSortingExtractor(session_path, keep_mua_units=False)

    electrode_group = []
    for shankn in np.arange(1, nshanks + 1, dtype=int):
        for id in se_allshanks.sortings[shankn - 1].get_unit_ids():
            electrode_group.append(nwbfile.electrode_groups['shank' + str(shankn)])

    df_unit_features = get_UnitFeatureCell_features(fpath_base, session_id, session_path)

    celltype_names = []
    for celltype_id, region_id in zip(df_unit_features['fineCellType'].values,
                                      df_unit_features['region'].values):
        if celltype_id == 1:
            if region_id == 3:
                celltype_names.append('pyramidal cell')
            elif region_id == 4:
                celltype_names.append('granule cell')
            else:
                raise Exception('unknown type')
        elif not np.isfinite(celltype_id):
            celltype_names.append('missing')
        else:
            celltype_names.append(celltype_dict[celltype_id])

    # Add custom column data into the SortingExtractor so it can be written by the converter.
    # Note there is currently a hidden assumption that the way in which the NeuroscopeSortingExtractor
    # merges the cluster IDs matches one-to-one with the get_UnitFeatureCell_features extraction.
    property_descriptions = {'cell_type': 'name of cell type',
                             'global_id': 'global id for cell for entire experiment',
                             'shank_id': '0-indexed id of cluster of shank',
                             'electrode_group': 'the electrode group that each spike unit came from'}
    property_values = {'cell_type': celltype_names,
                       'global_id': df_unit_features['unitID'].values,
                       'shank_id': [x - 2 for x in df_unit_features['unitIDshank'].values],
                       # - 2 b/c the get_UnitFeatureCell_features removes 0 and 1 IDs from each shank
                       'electrode_group': electrode_group}
    for unit_id in se_allshanks.get_unit_ids():
        for property_name in property_descriptions.keys():
            se_allshanks.set_unit_property(unit_id, property_name, property_values[property_name][unit_id])

    se.NwbSortingExtractor.write_sorting(se_allshanks, nwbfile=nwbfile,
                                         property_descriptions=property_descriptions)
    print('done.')

    # Read and write LFPs
    print('reading LFPs...', end='', flush=True)
    lfp_fs, all_channels_lfp_data = ns.read_lfp(session_path, stub=stub)
    lfp_data = all_channels_lfp_data[:, all_shank_channels]
    print('writing LFPs...', flush=True)
    # lfp_data[:int(len(lfp_data)/4)]
    lfp_ts = ns.write_lfp(nwbfile, lfp_data, lfp_fs, name='lfp',
                          description='lfp signal for all shank electrodes')

    # Read and add special environmental electrodes
    for name, channel in special_electrode_dict.items():
        ts = TimeSeries(name=name,
                        description='environmental electrode recorded inline with neural data',
                        data=all_channels_lfp_data[:, channel],
                        rate=lfp_fs,
                        unit='V',
                        # conversion=np.nan,
                        resolution=np.nan)
        nwbfile.add_acquisition(ts)

    # compute filtered LFP
    print('filtering LFP...', end='', flush=True)
    all_lfp_phases = []
    for passband in ('theta', 'gamma'):
        lfp_fft = filter_lfp(lfp_data[:, all_shank_channels == lfp_channel].ravel(),
                             lfp_fs, passband=passband)
        lfp_phase, _ = hilbert_lfp(lfp_fft)
        all_lfp_phases.append(lfp_phase[:, np.newaxis])
    data = np.dstack(all_lfp_phases)
    print('done.', flush=True)

    if include_spike_waveforms:
        print('writing waveforms...', end='', flush=True)
        nshanks = min((max_shanks, len(ns.get_shank_channels(session_path))))
        for shankn in np.arange(nshanks, dtype=int) + 1:
            # Get spike activity from the .spk file on a per-shank and per-sample basis
            ns.write_spike_waveforms(nwbfile, session_path, shankn, stub=stub)
        print('done.', flush=True)

    # Get the LFP Decomposition Series
    decomp_series = DecompositionSeries(name='LFPDecompositionSeries',
                                        description='Theta and Gamma phase for reference LFP',
                                        data=data,
                                        rate=lfp_fs,
                                        source_timeseries=lfp_ts,
                                        metric='phase',
                                        unit='radians')
    decomp_series.add_band(band_name='theta', band_limits=(4, 10))
    decomp_series.add_band(band_name='gamma', band_limits=(30, 80))
    check_module(nwbfile, 'ecephys',
                 'contains processed extracellular electrophysiology data').add_data_interface(decomp_series)

    [nwbfile.add_stimulus(x) for x in ns.get_events(session_path)]

    # create epochs corresponding to experiments/environments for the mouse
    sleep_state_fpath = os.path.join(session_path, '{}--StatePeriod.mat'.format(session_id))

    exist_pos_data = any(os.path.isfile(os.path.join(
        session_path, '{}__{}.mat'.format(session_id, task_type['name']))) for task_type in task_types)
    if exist_pos_data:
        nwbfile.add_epoch_column('label', 'name of epoch')

    for task_type in task_types:
        label = task_type['name']

        file = os.path.join(session_path, session_id + '__' + label + '.mat')
        if os.path.isfile(file):
            print('loading position for ' + label + '...', end='', flush=True)

            pos_obj = Position(name=label + '_position')

            matin = loadmat(file)
            tt = matin['twhl_norm'][:, 0]
            exp_times = find_discontinuities(tt)

            if 'conversion' in task_type:
                conversion = task_type['conversion']
            else:
                conversion = np.nan

            for pos_type in ('twhl_norm', 'twhl_linearized'):
                if pos_type in matin:
                    pos_data_norm = matin[pos_type][:, 1:]
                    spatial_series_object = SpatialSeries(
                        name=label + '_{}_spatial_series'.format(pos_type),
                        data=H5DataIO(pos_data_norm, compression='gzip'),
                        reference_frame='unknown',
                        conversion=conversion,
                        resolution=np.nan,
                        timestamps=H5DataIO(tt, compression='gzip'))
                    pos_obj.add_spatial_series(spatial_series_object)

            check_module(nwbfile, 'behavior',
                         'contains processed behavioral data').add_data_interface(pos_obj)
            for i, window in enumerate(exp_times):
                nwbfile.add_epoch(start_time=window[0], stop_time=window[1],
                                  label=label + '_' + str(i))
            print('done.')

    # there are occasional mismatches between the matlab struct and the neuroscope files
    # regions: 3: 'CA3', 4: 'DG'
    trialdata_path = os.path.join(session_path, session_id + '__EightMazeRun.mat')
    if os.path.isfile(trialdata_path):
        trials_data = loadmat(trialdata_path)['EightMazeRun']

        trialdatainfo_path = os.path.join(fpath_base, 'EightMazeRunInfo.mat')
        trialdatainfo = [x[0] for x in loadmat(trialdatainfo_path)['EightMazeRunInfo'][0]]

        features = trialdatainfo[:7]
        features[:2] = 'start_time', 'stop_time'

        [nwbfile.add_trial_column(x, 'description') for x in features[4:] + ['condition']]

        for trial_data in trials_data:
            if trial_data[3]:
                cond = 'run_left'
            else:
                cond = 'run_right'
            nwbfile.add_trial(start_time=trial_data[0], stop_time=trial_data[1], condition=cond,
                              error_run=trial_data[4], stim_run=trial_data[5], both_visit=trial_data[6])

    """
    mono_syn_fpath = os.path.join(session_path, session_id + '-MonoSynConvClick.mat')
    matin = loadmat(mono_syn_fpath)
    exc = matin['FinalExcMonoSynID']
    inh = matin['FinalInhMonoSynID']
    # exc_obj = CatCellInfo(name='excitatory_connections',
    #                       indices_values=[], cell_index=exc[:, 0] - 1, indices=exc[:, 1] - 1)
    # module_cellular.add_container(exc_obj)
    # inh_obj = CatCellInfo(name='inhibitory_connections',
    #                       indices_values=[], cell_index=inh[:, 0] - 1, indices=inh[:, 1] - 1)
    # module_cellular.add_container(inh_obj)
    """

    if os.path.isfile(sleep_state_fpath):
        matin = loadmat(sleep_state_fpath)['StatePeriod']

        table = TimeIntervals(name='states', description='sleep states of animal')
        table.add_column(name='label', description='sleep state')

        data = []
        for name in matin.dtype.names:
            for row in matin[name][0][0]:
                data.append({'start_time': row[0], 'stop_time': row[1], 'label': name})
        [table.add_row(**row) for row in sorted(data, key=lambda x: x['start_time'])]

        check_module(nwbfile, 'behavior', 'contains behavioral data').add_data_interface(table)

    print('writing NWB file...', end='', flush=True)
    if stub:
        out_fname = session_path + '_stub.nwb'
    else:
        out_fname = session_path + '.nwb'

    with NWBHDF5IO(out_fname, mode='w') as io:
        io.write(nwbfile, cache_spec=cache_spec)
    print('done.')

    print('testing read...', end='', flush=True)
    # test read
    with NWBHDF5IO(out_fname, mode='r') as io:
        io.read()
    print('done.')