# Module-level imports assumed by the snippets below; HAVE_EXDIR mirrors the
# optional-dependency flag that the extractor functions assert on.
import os
import os.path as op

import numpy as np
import quantities as pq

try:
    import exdir
    import exdir.plugins.quantities
    HAVE_EXDIR = True
except ImportError:
    HAVE_EXDIR = False


def generate_spike_trains(exdir_path, openephys_rec, source='klusta'):
    import neo
    if source == 'klusta':
        # TODO acquire features and masks
        print('Generating spike trains from KWIK file')
        exdir_file = exdir.File(exdir_path)
        acquisition = exdir_file["acquisition"]
        openephys_session = acquisition.attrs["openephys_session"]
        klusta_directory = op.join(str(acquisition.directory),
                                   openephys_session, 'klusta')
        n = 0
        for root, dirs, files in os.walk(klusta_directory):
            for f in files:
                if not f.endswith('_klusta.kwik'):
                    continue
                n += 1
                kwikfile = op.join(root, f)
                kwikio = neo.io.KwikIO(filename=kwikfile)
                blk = kwikio.read_block(raw_data_units='uV')
                seg = blk.segments[0]
                try:
                    exdirio = neo.io.ExdirIO(exdir_path)
                    exdirio.write_block(blk)
                except Exception:
                    print('WARNING: unable to convert\n', kwikfile)
        if n == 0:
            raise IOError('.kwik file cannot be found in ' + klusta_directory)
    elif source == 'openephys':
        exdirio = neo.io.ExdirIO(exdir_path)
        for oe_group in openephys_rec.channel_groups:
            channel_ids = [ch.id for ch in oe_group.channels]
            channel_index = [ch.index for ch in oe_group.channels]
            chx = neo.ChannelIndex(name='channel group {}'.format(oe_group.id),
                                   channel_ids=channel_ids,
                                   index=channel_index,
                                   group_id=oe_group.id)
            for sptr in oe_group.spiketrains:
                unit = neo.Unit(cluster_group='unsorted',
                                cluster_id=sptr.attrs['cluster_id'],
                                name=sptr.attrs['name'])
                unit.spiketrains.append(
                    neo.SpikeTrain(times=sptr.times,
                                   waveforms=sptr.waveforms,
                                   sampling_rate=sptr.sample_rate,
                                   t_stop=sptr.t_stop,
                                   **sptr.attrs))
                chx.units.append(unit)
            exdirio.write_channelindex(chx, start_time=0 * pq.s,
                                       stop_time=openephys_rec.duration)
    elif source == 'kilosort':
        print('Generating spike trains from KiloSort')
        exdirio = neo.io.ExdirIO(exdir_path)
        exdir_file = exdir.File(exdir_path)
        openephys_directory = op.join(
            str(exdir_file["acquisition"].directory),
            exdir_file["acquisition"].attrs["openephys_session"])
        # Iterate over channel groups. As no channels are associated with any
        # spiking unit yet (TO BE IMPLEMENTED), everything is written to
        # channel_group 0.
        for oe_group in openephys_rec.channel_groups:
            channel_ids = [ch.id for ch in oe_group.channels]
            channel_index = [ch.index for ch in oe_group.channels]
            chx = neo.ChannelIndex(name='channel group {}'.format(oe_group.id),
                                   channel_ids=channel_ids,
                                   index=channel_index,
                                   group_id=oe_group.id)
            # load KiloSort output
            spt = np.load(op.join(openephys_directory,
                                  'spike_times.npy')).flatten()
            spc = np.load(op.join(openephys_directory,
                                  'spike_clusters.npy')).flatten()
            try:
                cgroup = np.loadtxt(op.join(openephys_directory,
                                            'cluster_group.tsv'),
                                    dtype=[('cluster_id', 'i4'),
                                           ('group', 'U8')],
                                    skiprows=1)
                if cgroup.shape == (0,):
                    raise FileNotFoundError
            except FileNotFoundError:
                # manual curation didn't happen; label every cluster unsorted
                cgroup = np.array(
                    list(zip(np.unique(spc),
                             ['unsorted'] * np.unique(spc).size)),
                    dtype=[('cluster_id', 'i4'), ('group', 'U8')])
            for cluster_id, grp in cgroup:
                unit = neo.Unit(cluster_group=str(grp),
                                cluster_id=cluster_id,
                                name=cluster_id)
                unit.spiketrains.append(
                    neo.SpikeTrain(
                        times=(spt[spc == cluster_id].astype(float) /
                               openephys_rec.sample_rate).simplified,
                        t_stop=openephys_rec.duration))
                chx.units.append(unit)
            exdirio.write_channelindex(chx, start_time=0 * pq.s,
                                       stop_time=openephys_rec.duration)
            break
    else:
        raise ValueError(source + ' not supported')
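
# The KiloSort branch above falls back to labelling every cluster 'unsorted'
# when no curated cluster_group.tsv exists. A standalone sketch of that
# fallback, using toy spike_clusters contents (illustrative only):
spc_toy = np.array([3, 1, 3, 2, 1])
cgroup_toy = np.array(list(zip(np.unique(spc_toy),
                               ['unsorted'] * np.unique(spc_toy).size)),
                      dtype=[('cluster_id', 'i4'), ('group', 'U8')])
for cid_toy, grp_toy in cgroup_toy:
    print(cid_toy, grp_toy)  # -> 1 unsorted, 2 unsorted, 3 unsorted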
def get_duration(data_path):
    f = exdir.File(str(data_path), 'r', plugins=[exdir.plugins.quantities])
    return f.attrs['session_duration'].rescale('s')
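
# Quick round-trip sketch for get_duration (a hypothetical path; assumes the
# exdir file was written with the quantities plugin so that the attribute
# comes back as a quantity):
f = exdir.File('main.exdir', plugins=[exdir.plugins.quantities])
f.attrs['session_duration'] = 300 * pq.s
f.close()
assert get_duration('main.exdir') == 300 * pq.s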
def write_sorting(sorting, save_path, recording=None, sample_rate=None,
                  save_waveforms=False, verbose=False):
    assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
    if sample_rate is None and recording is None:
        raise Exception("Provide 'sample_rate' argument (Hz)")
    if recording is None:
        sample_rate = sample_rate * pq.Hz
    else:
        sample_rate = recording.get_sampling_frequency() * pq.Hz
    exdir_group = exdir.File(save_path, plugins=exdir.plugins.quantities)
    ephys = exdir_group.require_group('processing').require_group('electrophysiology')
    ephys.attrs['sample_rate'] = sample_rate
    if 'group' in sorting.get_shared_unit_property_names():
        channel_groups = np.unique([sorting.get_unit_property(unit, 'group')
                                    for unit in sorting.get_unit_ids()])
    else:
        channel_groups = [0]

    if len(channel_groups) == 1:
        chan = 0
        if verbose:
            print("Single group: ", chan)
        ch_group = ephys.require_group('channel_group_' + str(chan))
        # remove preexisting spike sorting data
        for key in ('UnitTimes', 'EventWaveform', 'Clustering'):
            try:
                del ch_group[key]
            except KeyError:
                pass
        unittimes = ch_group.require_group('UnitTimes')
        unit_stop_time = np.max([np.max(sorting.get_unit_spike_train(u).astype(float)
                                        / sample_rate).rescale('s')
                                 for u in sorting.get_unit_ids()]) * pq.s
        recording_stop_time = None
        if recording is not None:
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.array([])
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
            ch_group.attrs['start_time'] = 0 * pq.s
            recording_stop_time = recording.get_num_frames() / \
                float(recording.get_sampling_frequency()) * pq.s
            unittimes.attrs['electrode_group_id'] = chan
            unittimes.attrs['electrode_identities'] = np.array([])
            unittimes.attrs['electrode_idx'] = np.array(recording.get_channel_ids())
            unittimes.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['sample_rate'] = sample_rate
        if recording_stop_time is not None:
            stop_time = max(recording_stop_time, unit_stop_time)
            unittimes.attrs['stop_time'] = stop_time
            ch_group.attrs['stop_time'] = stop_time
        nums = np.array([])
        timestamps = np.array([])
        waveforms = np.array([])
        for unit in sorting.get_unit_ids():
            unit_group = unittimes.require_group(str(unit))
            unit_times = (sorting.get_unit_spike_train(unit).astype(float)
                          / sample_rate).rescale('s')
            unit_group.require_dataset('times', data=unit_times)
            unit_group.attrs['cluster_group'] = 'unsorted'
            unit_group.attrs['group_id'] = chan
            unit_group.attrs['name'] = 'unit #' + str(unit)
            timestamps = np.concatenate((timestamps, unit_times))
            nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit))))
            if 'waveforms' in sorting.get_unit_spike_feature_names(unit):
                if len(waveforms) == 0:
                    waveforms = sorting.get_unit_spike_features(unit, 'waveforms')
                else:
                    waveforms = np.vstack((waveforms,
                                           sorting.get_unit_spike_features(unit, 'waveforms')))
        if save_waveforms:
            if verbose:
                print("Saving EventWaveforms")
            if 'waveforms' in sorting.get_shared_unit_spike_feature_names():
                eventwaveform = ch_group.require_group('EventWaveform')
                waveform_ts = eventwaveform.require_group('waveform_timeseries')
                data = waveform_ts.require_dataset('data', data=waveforms)
                waveform_ts.attrs['electrode_group_id'] = chan
                data.attrs['num_samples'] = len(waveforms)
                data.attrs['sample_rate'] = sample_rate
                data.attrs['unit'] = pq.dimensionless
                times = waveform_ts.require_dataset('timestamps', data=timestamps)
                times.attrs['num_samples'] = len(timestamps)
                times.attrs['unit'] = pq.s
                if recording is not None:
                    waveform_ts.attrs['electrode_identities'] = np.array([])
                    waveform_ts.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
                    waveform_ts.attrs['start_time'] = 0 * pq.s
                    if recording_stop_time is not None:
                        waveform_ts.attrs['stop_time'] = max(recording_stop_time,
                                                             unit_stop_time)
                    waveform_ts.attrs['sample_rate'] = sample_rate
                    waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                    waveform_ts.attrs['num_samples'] = len(waveforms)
            if verbose:
                print("Saving Clustering")
            clustering = ch_group.require_group('Clustering')
            ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
            ts.attrs['num_samples'] = len(timestamps)
            ts.attrs['unit'] = pq.s
            ns = clustering.require_dataset('nums', data=nums)
            ns.attrs['num_samples'] = len(nums)
            cn = clustering.require_dataset('cluster_nums',
                                            data=np.array(sorting.get_unit_ids()))
            cn.attrs['num_samples'] = len(sorting.get_unit_ids())
    else:
        # remove preexisting spike sorting data
        max_group = 10
        for chan in np.arange(max_group):
            if 'channel_group_' + str(chan) in ephys.keys():
                if verbose:
                    print('Removing channel', chan, 'info')
                ch_group = ephys.require_group('channel_group_' + str(chan))
                for key in ('UnitTimes', 'EventWaveform', 'Clustering'):
                    try:
                        del ch_group[key]
                    except KeyError:
                        pass
        channel_groups = np.unique([sorting.get_unit_property(unit, 'group')
                                    for unit in sorting.get_unit_ids()])
        for chan in channel_groups:
            if verbose:
                print("Group: ", chan)
            ch_group = ephys.require_group('channel_group_' + str(chan))
            unittimes = ch_group.require_group('UnitTimes')
            unit_stop_time = np.max([np.max(sorting.get_unit_spike_train(u).astype(float)
                                            / sample_rate).rescale('s')
                                     for u in sorting.get_unit_ids()]) * pq.s
            recording_stop_time = None
            if recording is not None:
                group_channels = [ch for ch in recording.get_channel_ids()
                                  if recording.get_channel_property(ch, 'group') == chan]
                group_idx = [i_c for i_c, ch in enumerate(recording.get_channel_ids())
                             if recording.get_channel_property(ch, 'group') == chan]
                unittimes.attrs['electrode_group_id'] = chan
                unittimes.attrs['electrode_identities'] = np.array([])
                unittimes.attrs['electrode_idx'] = np.array(group_channels)
                unittimes.attrs['start_time'] = 0 * pq.s
                recording_stop_time = recording.get_num_frames() / \
                    float(recording.get_sampling_frequency()) * pq.s
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(group_idx)
                ch_group.attrs['electrode_idx'] = np.array(group_idx)
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['sample_rate'] = sample_rate
            if recording_stop_time is not None:
                stop_time = max(recording_stop_time, unit_stop_time)
                unittimes.attrs['stop_time'] = stop_time
                ch_group.attrs['stop_time'] = stop_time
            nums = np.array([])
            timestamps = np.array([])
            waveforms = np.array([])
            for unit in sorting.get_unit_ids():
                if sorting.get_unit_property(unit, 'group') == chan:
                    if verbose:
                        print("Unit: ", unit)
                    unit_group = unittimes.require_group(str(unit))
                    unit_times = (sorting.get_unit_spike_train(unit).astype(float)
                                  / sample_rate).rescale('s')
                    unit_group.require_dataset('times', data=unit_times)
                    unit_group.attrs['cluster_group'] = 'unsorted'
                    unit_group.attrs['group_id'] = chan
                    unit_group.attrs['name'] = 'unit #' + str(unit)
                    timestamps = np.concatenate((timestamps, unit_times))
                    nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit))))
                    if 'waveforms' in sorting.get_unit_spike_feature_names(unit):
                        if len(waveforms) == 0:
                            waveforms = sorting.get_unit_spike_features(unit, 'waveforms')
                        else:
                            waveforms = np.vstack((waveforms,
                                                   sorting.get_unit_spike_features(unit, 'waveforms')))
            if save_waveforms:
                if verbose:
                    print("Saving EventWaveforms")
                if 'waveforms' in sorting.get_shared_unit_spike_feature_names():
                    eventwaveform = ch_group.require_group('EventWaveform')
                    waveform_ts = eventwaveform.require_group('waveform_timeseries')
                    data = waveform_ts.require_dataset('data', data=waveforms)
                    data.attrs['num_samples'] = len(waveforms)
                    data.attrs['sample_rate'] = sample_rate
                    data.attrs['unit'] = pq.dimensionless
                    times = waveform_ts.require_dataset('timestamps', data=timestamps)
                    times.attrs['num_samples'] = len(timestamps)
                    times.attrs['unit'] = pq.s
                    waveform_ts.attrs['electrode_group_id'] = chan
                    if recording is not None:
                        waveform_ts.attrs['electrode_identities'] = np.array([])
                        waveform_ts.attrs['electrode_idx'] = np.array(
                            [ch for i_c, ch in enumerate(recording.get_channel_ids())
                             if recording.get_channel_property(ch, 'group') == chan])
                        waveform_ts.attrs['start_time'] = 0 * pq.s
                        if recording_stop_time is not None:
                            waveform_ts.attrs['stop_time'] = max(recording_stop_time,
                                                                 unit_stop_time)
                        waveform_ts.attrs['sample_rate'] = sample_rate
                        waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                        waveform_ts.attrs['num_samples'] = len(waveforms)
                if verbose:
                    print("Saving Clustering")
                clustering = ch_group.require_group('Clustering')
                ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
                ts.attrs['num_samples'] = len(timestamps)
                ts.attrs['unit'] = pq.s
                ns = clustering.require_dataset('nums', data=nums)
                ns.attrs['num_samples'] = len(nums)
                cn = clustering.require_dataset('cluster_nums',
                                                data=np.array(sorting.get_unit_ids()))
                cn.attrs['num_samples'] = len(sorting.get_unit_ids())
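
# A minimal duck-typed sorting stub, sufficient for write_sorting above; this
# is a sketch for illustration, not part of any real extractor API. It carries
# only the methods write_sorting calls in the single-group, no-waveforms path.
class ToySorting:
    def __init__(self, trains):
        self._trains = trains  # {unit_id: spike frames as np.ndarray}

    def get_unit_ids(self):
        return list(self._trains)

    def get_unit_spike_train(self, unit):
        return self._trains[unit]

    def get_shared_unit_property_names(self):
        return []

    def get_unit_spike_feature_names(self, unit):
        return []

    def get_shared_unit_spike_feature_names(self):
        return []


toy_sorting = ToySorting({0: np.array([120, 900, 1500]), 1: np.array([300, 600])})
write_sorting(toy_sorting, 'toy.exdir', sample_rate=30000)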
def write_recording(recording, save_path, lfp=False, mua=False):
    assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
    channel_ids = recording.get_channel_ids()
    raw = recording.get_traces()
    exdir_group = exdir.File(save_path, plugins=[exdir.plugins.quantities])
    if not lfp and not mua:
        acq = exdir_group.require_group('acquisition')
        timeseries = acq.require_dataset('timeseries', data=raw)
        timeseries.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
        timeseries.attrs['electrode_identities'] = np.array(channel_ids)
        return
    elif lfp:
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
        if 'group' in recording.get_shared_channel_property_names():
            channel_groups = np.unique([recording.get_channel_property(ch, 'group')
                                        for ch in recording.get_channel_ids()])
        else:
            channel_groups = [0]
        if len(channel_groups) == 1:
            chan = 0
            ch_group = ephys.require_group('channel_group_' + str(chan))
            lfp_group = ch_group.require_group('LFP')
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids())
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = recording.get_num_frames() / \
                float(recording.get_sampling_frequency()) * pq.s
            for i_c, ch in enumerate(recording.get_channel_ids()):
                ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = recording.get_num_frames()
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = recording.get_num_frames() / \
                    float(recording.get_sampling_frequency()) * pq.s
                ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                data = ts_group.require_dataset('data',
                                                data=recording.get_traces(channel_ids=[ch]))
                data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                data.attrs['unit'] = pq.uV
        else:
            channel_groups = np.unique([recording.get_channel_property(ch, 'group')
                                        for ch in recording.get_channel_ids()])
            for chan in channel_groups:
                ch_group = ephys.require_group('channel_group_' + str(chan))
                lfp_group = ch_group.require_group('LFP')
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(
                    [ch for ch in recording.get_channel_ids()
                     if recording.get_channel_property(ch, 'group') == chan])
                ch_group.attrs['electrode_idx'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.get_channel_ids())
                     if recording.get_channel_property(ch, 'group') == chan])
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = recording.get_num_frames() / \
                    float(recording.get_sampling_frequency()) * pq.s
                for i_c, ch in enumerate(recording.get_channel_ids()):
                    if recording.get_channel_property(ch, 'group') == chan:
                        ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))
                        ts_group.attrs['electrode_group_id'] = chan
                        ts_group.attrs['electrode_identity'] = ch
                        ts_group.attrs['num_samples'] = recording.get_num_frames()
                        ts_group.attrs['electrode_idx'] = i_c
                        ts_group.attrs['start_time'] = 0 * pq.s
                        ts_group.attrs['stop_time'] = recording.get_num_frames() / \
                            float(recording.get_sampling_frequency()) * pq.s
                        ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                        data = ts_group.require_dataset('data',
                                                        data=recording.get_traces(channel_ids=[ch]))
                        data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                        data.attrs['unit'] = pq.uV
        return
    elif mua:
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
        if 'group' in recording.get_shared_channel_property_names():
            channel_groups = np.unique([recording.get_channel_property(ch, 'group')
                                        for ch in recording.get_channel_ids()])
        else:
            channel_groups = [0]
        if len(channel_groups) == 1:
            chan = 0
            ch_group = ephys.require_group('channel_group_' + str(chan))
            mua_group = ch_group.require_group('MUA')
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids())
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = recording.get_num_frames() / \
                float(recording.get_sampling_frequency()) * pq.s
            for i_c, ch in enumerate(recording.get_channel_ids()):
                ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = recording.get_num_frames()
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = recording.get_num_frames() / \
                    float(recording.get_sampling_frequency()) * pq.s
                ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                data = ts_group.require_dataset('data',
                                                data=recording.get_traces(channel_ids=[ch]))
                data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                data.attrs['unit'] = pq.uV
        else:
            channel_groups = np.unique([recording.get_channel_property(ch, 'group')
                                        for ch in recording.get_channel_ids()])
            for chan in channel_groups:
                ch_group = ephys.require_group('channel_group_' + str(chan))
                mua_group = ch_group.require_group('MUA')
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(
                    [ch for ch in recording.get_channel_ids()
                     if recording.get_channel_property(ch, 'group') == chan])
                ch_group.attrs['electrode_idx'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.get_channel_ids())
                     if recording.get_channel_property(ch, 'group') == chan])
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = recording.get_num_frames() / \
                    float(recording.get_sampling_frequency()) * pq.s
                for i_c, ch in enumerate(recording.get_channel_ids()):
                    if recording.get_channel_property(ch, 'group') == chan:
                        ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))
                        ts_group.attrs['electrode_group_id'] = chan
                        ts_group.attrs['electrode_identity'] = ch
                        ts_group.attrs['num_samples'] = recording.get_num_frames()
                        ts_group.attrs['electrode_idx'] = i_c
                        ts_group.attrs['start_time'] = 0 * pq.s
                        ts_group.attrs['stop_time'] = recording.get_num_frames() / \
                            float(recording.get_sampling_frequency()) * pq.s
                        ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                        data = ts_group.require_dataset('data',
                                                        data=recording.get_traces(channel_ids=[ch]))
                        data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz
                        data.attrs['unit'] = pq.uV
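
# A duck-typed recording stub (same idea as ToySorting above) carrying just
# the methods write_recording calls; purely illustrative.
class ToyRecording:
    def __init__(self, traces, fs):
        self._traces, self._fs = traces, fs  # traces: (channels, frames)

    def get_channel_ids(self):
        return list(range(self._traces.shape[0]))

    def get_traces(self, channel_ids=None):
        if channel_ids is None:
            return self._traces
        return self._traces[channel_ids]

    def get_sampling_frequency(self):
        return self._fs

    def get_num_frames(self):
        return self._traces.shape[1]

    def get_shared_channel_property_names(self):
        return []


toy_rec = ToyRecording(np.random.randn(4, 30000), 30000.0)
write_recording(toy_rec, 'toy.exdir')            # raw -> /acquisition/timeseries
write_recording(toy_rec, 'toy.exdir', lfp=True)  # one LFP timeseries per channel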
def _add_to_Exdir(self, name, target_dir, level_dir):
    """Register `name` as a raw entry in experiment.exdir.

    Note: `target_dir` is currently unused, and the raw entry is created at
    the file root rather than inside the required group.
    """
    experiment = exdir.File("experiment.exdir")
    group = experiment.require_group(level_dir)
    experiment.require_raw(name)
def experiment_plot(project_path, action_id, n_channel=8, rem_channel="all",
                    skip_channels=None, raster_start=-0.5, raster_stop=1):
    """
    Plot raster, ISI mean/median and tuning (polar and linear) figures from
    visual data acquired through Open Ephys and PsychoPy.

    Parameters
    ----------
    project_path : os.path or equivalent
        Path to project directory
    action_id : str
        The experiment/action id
    n_channel : int, default 8
        Number of channel groups in the recordings
    rem_channel : "all" or int, default "all"
        Which channel group(s) are of interest
    skip_channels : int or (list, tuple, array), default None
        IDs of channel groups to skip
    raster_start : int or float, default -0.5
        Where the rasters start on the x-axis, relative to trial start (0)
    raster_stop : int or float, default 1
        Where the rasters stop on the x-axis, relative to trial start (0)

    Returns
    -------
    None
        Saves figures into the exdir directory: main.exdir/figures
    """
    if not (rem_channel == "all" or
            (isinstance(rem_channel, int) and rem_channel < n_channel)):
        msg = ("rem_channel must be either 'all' or an integer between 0 and "
               "n_channel ({}); not {}".format(n_channel, rem_channel))
        raise AttributeError(msg)
    # Define project tree
    project = expipe.get_project(project_path)
    action = project.actions[action_id]
    data_path = er.get_data_path(action)
    epochs = er.load_epochs(data_path)
    # Get data of interest (orientations vs rates vs channel)
    oe_epoch = epochs[0]  # openephys
    assert oe_epoch.annotations['provenance'] == 'open-ephys'
    ps_epoch = epochs[1]  # psychopy
    assert ps_epoch.annotations['provenance'] == 'psychopy'
    # Create directory for figures
    exdir_file = exdir.File(data_path, plugins=exdir.plugins.quantities)
    figures_group = exdir_file.require_group('figures')
    raster_start = raster_start * pq.s
    raster_stop = raster_stop * pq.s
    orients = ps_epoch.labels  # the labels are orientations (135, 90, ...)
    def plot(channel_num, channel_path, spiketrains):
        # Create figures from spiketrains
        for spiketrain in spiketrains:
            try:
                if spiketrain.annotations["cluster_group"] == "noise":
                    continue
            except KeyError:
                msg = ("Cluster/channel group {} seems to not have been "
                       "sorted in phy".format(channel_num))
                raise KeyError(msg)
            figure_id = "{}_{}_".format(channel_num,
                                        spiketrain.annotations['cluster_id'])
            sns.set()
            sns.set_style("white")
            # Raster plot processing
            trials = er.make_spiketrain_trials(spiketrain, oe_epoch,
                                               t_start=raster_start,
                                               t_stop=raster_stop)
            er.add_orientation_to_trials(trials, orients)
            orf_path = os.path.join(channel_path, figure_id + "orrient_raster.png")
            orient_raster_fig = orient_raster_plots(trials)
            orient_raster_fig.savefig(orf_path)
            # Orientation vs. spike frequency (tuning curves) processing
            trials = er.make_spiketrain_trials(spiketrain, oe_epoch)
            er.add_orientation_to_trials(trials, orients)
            tf_path = os.path.join(channel_path, figure_id + "tuning.png")
            tuning_fig = plot_tuning_overview(trials, spiketrain)
            tuning_fig.savefig(tf_path)
            # Reset before next loop to save memory
            plt.close(fig="all")

    if rem_channel == "all":
        channels = list(range(n_channel))
        if isinstance(skip_channels, int):
            skip_channels = [skip_channels]
        if isinstance(skip_channels, (list, tuple, np.ndarray)):
            channels = [c for c in channels if c not in skip_channels]
        for channel in channels:
            channel_name = "channel_{}".format(channel)
            channel_group = figures_group.require_group(channel_name)
            channel_path = os.path.join(str(data_path), "figures", channel_name)
            spiketrains = er.load_spiketrains(str(data_path), channel)
            plot(channel, channel_path, spiketrains)
    elif isinstance(rem_channel, int):
        channel_name = "channel_{}".format(rem_channel)
        channel_group = figures_group.require_group(channel_name)
        channel_path = os.path.join(str(data_path), "figures", channel_name)
        spiketrains = er.load_spiketrains(str(data_path), rem_channel)
        plot(rem_channel, channel_path, spiketrains)
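
# Example invocation (a sketch; the project path and action id are
# illustrative and must exist in your expipe project):
experiment_plot('/projects/visual_cortex', '210101_1_rat42',
                n_channel=8, rem_channel='all', skip_channels=[3],
                raster_start=-0.5, raster_stop=1)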
def writeSorting(sorting, exdir_file, recording=None, sample_rate=None):
    exdir, pq = _load_required_modules()
    if sample_rate is None and recording is None:
        raise Exception("Provide 'sample_rate' argument (Hz)")
    else:
        if recording is None:
            sample_rate = sample_rate * pq.Hz
        else:
            sample_rate = recording.getSamplingFrequency() * pq.Hz
    exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)
    ephys = exdir_group.require_group('processing').require_group('electrophysiology')
    # NOTE: the 'group' lookup below assumes `recording` is provided; with
    # recording=None it fails.
    if 'group' in recording.getChannelPropertyNames():
        channel_groups = np.unique([sorting.getUnitProperty(unit, 'group')
                                    for unit in sorting.getUnitIds()])
    else:
        channel_groups = [0]

    if len(channel_groups) == 1:
        chan = 0
        print("Single group: ", chan)
        ch_group = ephys.require_group('Channel_group_' + str(chan))
        unittimes = ch_group.require_group('UnitTimes')
        eventwaveform = ch_group.require_group('EventWaveform')
        if recording is not None:
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.array([])
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.getChannelIds()))
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                float(recording.getSamplingFrequency()) * pq.s
            unittimes.attrs['electrode_group_id'] = chan
            unittimes.attrs['electrode_identities'] = np.array([])
            unittimes.attrs['electrode_idx'] = np.array(recording.getChannelIds())
            unittimes.attrs['start_time'] = 0 * pq.s
            unittimes.attrs['stop_time'] = recording.getNumFrames() / \
                float(recording.getSamplingFrequency()) * pq.s
        nums = np.array([])
        timestamps = np.array([])
        waveforms = np.array([])
        for unit in sorting.getUnitIds():
            unit_group = unittimes.require_group(str(unit))
            unit_group.require_dataset(
                'times',
                data=(sorting.getUnitSpikeTrain(unit).astype(float)
                      / sample_rate).rescale('s'))
            unit_group.attrs['cluster_group'] = 'unsorted'
            unit_group.attrs['group_id'] = chan
            unit_group.attrs['name'] = 'unit #' + str(unit)
            timestamps = np.concatenate(
                (timestamps,
                 (sorting.getUnitSpikeTrain(unit).astype(float)
                  / sample_rate).rescale('s')))
            nums = np.concatenate((nums, [unit] * len(sorting.getUnitSpikeTrain(unit))))
            if 'waveforms' in sorting.getUnitSpikeFeatureNames(unit):
                if len(waveforms) == 0:
                    waveforms = sorting.getUnitSpikeFeatures(unit, 'waveforms')
                else:
                    waveforms = np.vstack((waveforms,
                                           sorting.getUnitSpikeFeatures(unit, 'waveforms')))
        print("Saving eventwaveforms and clustering")
        if 'waveforms' in sorting.getUnitSpikeFeatureNames():
            waveform_ts = eventwaveform.require_group('waveform_timeseries')
            data = waveform_ts.require_dataset('data', data=waveforms)
            data.attrs['num_samples'] = len(waveforms)
            data.attrs['sample_rate'] = sample_rate
            data.attrs['unit'] = pq.dimensionless
            times = waveform_ts.require_dataset('timestamps', data=timestamps)
            times.attrs['num_samples'] = len(timestamps)
            times.attrs['unit'] = pq.s
            waveform_ts.attrs['electrode_group_id'] = chan
            if recording is not None:
                waveform_ts.attrs['electrode_identities'] = np.array([])
                waveform_ts.attrs['electrode_idx'] = np.arange(len(recording.getChannelIds()))
                waveform_ts.attrs['start_time'] = 0 * pq.s
                waveform_ts.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
            waveform_ts.attrs['sample_rate'] = sample_rate
            waveform_ts.attrs['sample_length'] = waveforms.shape[1]
            waveform_ts.attrs['num_samples'] = len(waveforms)
        clustering = ephys.require_group('Channel_group_' + str(chan)).require_group('Clustering')
        ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
        ts.attrs['num_samples'] = len(timestamps)
        ts.attrs['unit'] = pq.s
        ns = clustering.require_dataset('nums', data=nums)
        ns.attrs['num_samples'] = len(nums)
        cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.getUnitIds()))
        cn.attrs['num_samples'] = len(sorting.getUnitIds())
    else:
        channel_groups = np.unique([sorting.getUnitProperty(unit, 'group')
                                    for unit in sorting.getUnitIds()])
        for chan in channel_groups:
            print("Group: ", chan)
            ch_group = ephys.require_group('Channel_group_' + str(chan))
            unittimes = ch_group.require_group('UnitTimes')
            eventwaveform = ch_group.require_group('EventWaveform')
            if recording is not None:
                unittimes.attrs['electrode_group_id'] = chan
                unittimes.attrs['electrode_identities'] = np.array([])
                unittimes.attrs['electrode_idx'] = np.array(
                    [ch for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                unittimes.attrs['start_time'] = 0 * pq.s
                unittimes.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['electrode_idx'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
            nums = np.array([])
            timestamps = np.array([])
            waveforms = np.array([])
            for unit in sorting.getUnitIds():
                if sorting.getUnitProperty(unit, 'group') == chan:
                    print("Unit: ", unit)
                    unit_group = unittimes.require_group(str(unit))
                    unit_group.require_dataset(
                        'times',
                        data=(sorting.getUnitSpikeTrain(unit).astype(float)
                              / sample_rate).rescale('s'))
                    unit_group.attrs['cluster_group'] = 'unsorted'
                    unit_group.attrs['group_id'] = chan
                    unit_group.attrs['name'] = 'unit #' + str(unit)
                    timestamps = np.concatenate(
                        (timestamps,
                         (sorting.getUnitSpikeTrain(unit).astype(float)
                          / sample_rate).rescale('s')))
                    nums = np.concatenate((nums, [unit] * len(sorting.getUnitSpikeTrain(unit))))
                    if 'waveforms' in sorting.getUnitSpikeFeatureNames(unit):
                        if len(waveforms) == 0:
                            waveforms = sorting.getUnitSpikeFeatures(unit, 'waveforms')
                        else:
                            waveforms = np.vstack((waveforms,
                                                   sorting.getUnitSpikeFeatures(unit, 'waveforms')))
            print("Saving eventwaveforms and clustering")
            if 'waveforms' in sorting.getUnitSpikeFeatureNames():
                waveform_ts = eventwaveform.require_group('waveform_timeseries')
                data = waveform_ts.require_dataset('data', data=waveforms)
                data.attrs['num_samples'] = len(waveforms)
                data.attrs['sample_rate'] = sample_rate
                data.attrs['unit'] = pq.dimensionless
                times = waveform_ts.require_dataset('timestamps', data=timestamps)
                times.attrs['num_samples'] = len(timestamps)
                times.attrs['unit'] = pq.s
                waveform_ts.attrs['electrode_group_id'] = chan
                if recording is not None:
                    waveform_ts.attrs['electrode_identities'] = np.array([])
                    waveform_ts.attrs['electrode_idx'] = np.array(
                        [ch for i_c, ch in enumerate(recording.getChannelIds())
                         if recording.getChannelProperty(ch, 'group') == chan])
                    waveform_ts.attrs['start_time'] = 0 * pq.s
                    waveform_ts.attrs['stop_time'] = recording.getNumFrames() / \
                        float(recording.getSamplingFrequency()) * pq.s
                waveform_ts.attrs['sample_rate'] = sample_rate
                waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                waveform_ts.attrs['num_samples'] = len(waveforms)
            clustering = ephys.require_group('Channel_group_' + str(chan)).require_group('Clustering')
            ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
            ts.attrs['num_samples'] = len(timestamps)
            ts.attrs['unit'] = pq.s
            ns = clustering.require_dataset('nums', data=nums)
            ns.attrs['num_samples'] = len(nums)
            cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.getUnitIds()))
            cn.attrs['num_samples'] = len(sorting.getUnitIds())
def writeRecording(recording, exdir_file, lfp=False, mua=False):
    exdir, pq = _load_required_modules()
    channel_ids = recording.getChannelIds()
    M = len(channel_ids)
    N = recording.getNumFrames()
    raw = recording.getTraces()
    exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)
    if not lfp and not mua:
        timeseries = exdir_group.require_group('acquisition').require_dataset('timeseries',
                                                                              data=raw)
        timeseries.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
        return
    elif lfp:
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        if 'group' in recording.getChannelPropertyNames():
            channel_groups = np.unique([recording.getChannelProperty(ch, 'group')
                                        for ch in recording.getChannelIds()])
        else:
            channel_groups = [0]
        if len(channel_groups) == 1:
            chan = 0
            ch_group = ephys.require_group('Channel_group_' + str(chan))
            lfp_group = ch_group.require_group('LFP')
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.arange(len(recording.getChannelIds()))
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.getChannelIds()))
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                float(recording.getSamplingFrequency()) * pq.s
            for i_c, ch in enumerate(recording.getChannelIds()):
                ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = recording.getNumFrames()
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
                ts_group.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                data = ts_group.require_dataset('data',
                                                data=recording.getTraces(channel_ids=[ch]))
                data.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                data.attrs['unit'] = pq.uV
        else:
            channel_groups = np.unique([recording.getChannelProperty(ch, 'group')
                                        for ch in recording.getChannelIds()])
            for chan in channel_groups:
                ch_group = ephys.require_group('Channel_group_' + str(chan))
                lfp_group = ch_group.require_group('LFP')
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['electrode_idx'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
                for i_c, ch in enumerate(recording.getChannelIds()):
                    if recording.getChannelProperty(ch, 'group') == chan:
                        ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))
                        ts_group.attrs['electrode_group_id'] = chan
                        ts_group.attrs['electrode_identity'] = ch
                        ts_group.attrs['num_samples'] = recording.getNumFrames()
                        ts_group.attrs['electrode_idx'] = i_c
                        ts_group.attrs['start_time'] = 0 * pq.s
                        ts_group.attrs['stop_time'] = recording.getNumFrames() / \
                            float(recording.getSamplingFrequency()) * pq.s
                        ts_group.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                        data = ts_group.require_dataset('data',
                                                        data=recording.getTraces(channel_ids=[ch]))
                        data.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                        data.attrs['unit'] = pq.uV
        return
    elif mua:
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        if 'group' in recording.getChannelPropertyNames():
            channel_groups = np.unique([recording.getChannelProperty(ch, 'group')
                                        for ch in recording.getChannelIds()])
        else:
            channel_groups = [0]
        if len(channel_groups) == 1:
            chan = 0
            ch_group = ephys.require_group('Channel_group_' + str(chan))
            mua_group = ch_group.require_group('MUA')
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.arange(len(recording.getChannelIds()))
            ch_group.attrs['electrode_idx'] = np.arange(len(recording.getChannelIds()))
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                float(recording.getSamplingFrequency()) * pq.s
            for i_c, ch in enumerate(recording.getChannelIds()):
                ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = recording.getNumFrames()
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
                ts_group.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                data = ts_group.require_dataset('data',
                                                data=recording.getTraces(channel_ids=[ch]))
                data.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                data.attrs['unit'] = pq.uV
        else:
            channel_groups = np.unique([recording.getChannelProperty(ch, 'group')
                                        for ch in recording.getChannelIds()])
            for chan in channel_groups:
                ch_group = ephys.require_group('Channel_group_' + str(chan))
                mua_group = ch_group.require_group('MUA')
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['electrode_idx'] = np.array(
                    [i_c for i_c, ch in enumerate(recording.getChannelIds())
                     if recording.getChannelProperty(ch, 'group') == chan])
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = recording.getNumFrames() / \
                    float(recording.getSamplingFrequency()) * pq.s
                for i_c, ch in enumerate(recording.getChannelIds()):
                    if recording.getChannelProperty(ch, 'group') == chan:
                        ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))
                        ts_group.attrs['electrode_group_id'] = chan
                        ts_group.attrs['electrode_identity'] = ch
                        ts_group.attrs['num_samples'] = recording.getNumFrames()
                        ts_group.attrs['electrode_idx'] = i_c
                        ts_group.attrs['start_time'] = 0 * pq.s
                        ts_group.attrs['stop_time'] = recording.getNumFrames() / \
                            float(recording.getSamplingFrequency()) * pq.s
                        ts_group.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                        data = ts_group.require_dataset('data',
                                                        data=recording.getTraces(channel_ids=[ch]))
                        data.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
                        data.attrs['unit'] = pq.uV
def __init__(self, exdir_file, sample_rate=None, channel_group=None,
             load_waveforms=False):
    assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
    SortingExtractor.__init__(self)
    self._exdir_file = exdir_file
    exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)
    electrophysiology = None
    if 'processing' in exdir_group.keys():
        if 'electrophysiology' in exdir_group['processing']:
            electrophysiology = exdir_group['processing']['electrophysiology']
            ephys_attrs = electrophysiology.attrs
            if 'sample_rate' in ephys_attrs:
                sample_rate = ephys_attrs['sample_rate']
            else:
                if sample_rate is None:
                    raise Exception("Sampling rate information not found. "
                                    "Please provide it with the 'sample_rate' "
                                    "argument")
                else:
                    sample_rate = sample_rate * pq.Hz
    self._sampling_frequency = float(sample_rate.rescale('Hz').magnitude)
    if electrophysiology is None:
        raise Exception("'electrophysiology' group not found!")
    self._unit_ids = []
    current_unit = 1
    self._spike_trains = []
    for chan_name, channel in electrophysiology.items():
        if 'channel' in chan_name:
            group = int(chan_name.split('_')[-1])
            if channel_group is not None:
                if group != channel_group:
                    continue
            if load_waveforms:
                if 'Clustering' in channel.keys() and 'EventWaveform' in channel.keys():
                    clustering = channel.require_group('Clustering')
                    eventwaveform = channel.require_group('EventWaveform')
                    nums = clustering['nums'].data
                    waveforms = eventwaveform.require_group('waveform_timeseries')['data'].data
            if 'UnitTimes' in channel.keys():
                for unit, unit_times in channel['UnitTimes'].items():
                    self._unit_ids.append(current_unit)
                    self._spike_trains.append(
                        (unit_times['times'].data.rescale('s') * sample_rate).magnitude)
                    attrs = unit_times.attrs
                    for k, v in attrs.items():
                        self.set_unit_property(current_unit, k, v)
                    if load_waveforms:
                        unit_idxs = np.where(nums == int(unit))
                        wf = waveforms[unit_idxs]
                        self.set_unit_spike_features(current_unit, 'waveforms', wf)
                    current_unit += 1
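
# Round-trip sketch: once write_sorting above has populated
# processing/electrophysiology/channel_group_0/UnitTimes, this extractor can
# read the units back. The class name is assumed from the assert message, and
# the usual get_unit_ids/get_unit_spike_train accessors are assumed to be
# provided by the SortingExtractor base class over the lists filled here.
sorting_back = ExdirSortingExtractor('toy.exdir')
for u in sorting_back.get_unit_ids():
    print(u, len(sorting_back.get_unit_spike_train(u)))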
def setup():
    # `tmpdir` is expected to be supplied by the surrounding test/benchmark
    # harness.
    testpath = str(tmpdir.mkdir("test").join("test.exdir"))
    if os.path.exists(testpath):
        shutil.rmtree(testpath)
    f = exdir.File(testpath)
    return ((f, ), {})
# This is a usecase that shows how h5py can be swapped with exdir.
# usecase_h5py.py shows the same usecase with h5py instead

import exdir
import numpy as np

time = np.linspace(0, 100, 101)
voltage_1 = np.sin(time)
voltage_2 = np.sin(time) + 10

f = exdir.File("experiments.exdir", "w")
f.attrs['description'] = "This is a mock experiment with voltage values over time"

# Creating group and datasets for experiment 1
grp_1 = f.create_group("experiment_1")

dset_time_1 = grp_1.create_dataset("time", data=time)
dset_time_1.attrs['unit'] = "ms"

dset_voltage_1 = grp_1.create_dataset("voltage", data=voltage_1)
dset_voltage_1.attrs['unit'] = "mV"

# Creating group and datasets for experiment 2
grp_2 = f.create_group("experiment_2")

dset_time_2 = grp_2.create_dataset("time", data=time)
dset_time_2.attrs['unit'] = "ms"

dset_voltage_2 = grp_2.create_dataset("voltage", data=voltage_2)
dset_voltage_2.attrs['unit'] = "mV"
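
# Reading the file back mirrors h5py as well (a short sketch):
g = exdir.File("experiments.exdir", "r")
for name in ("experiment_1", "experiment_2"):
    grp = g[name]
    print(name, grp["voltage"].data.mean(), grp["voltage"].attrs["unit"])
g.close()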
def _openResources(self):
    """ Opens the root Dataset. """
    logger.info("Opening: {}".format(self._fileName))
    self._exdirGroup = exdir.File(self._fileName, mode='r')
def load_exdir(filename, db=None, filename_db=None, lazy=False):
    """
    Loads measurement state from an ExDir file system and, if provided, a
    database.

    Parameters
    ----------
    filename : str
        Absolute path to the exdir file.
    db : MyDatabase
        Bound Pony database instance.
    filename_db : str, optional
        Filename used for the database lookup; defaults to `filename`.
    lazy : bool
        If True, the function leaves the ExDir file open and sets
        retval.exdir to this file.

    Returns
    -------
    retval : MeasurementState
        Measurement state obtained by combining the data from ExDir under
        `filename` with the database record that has the same filename.
    """
    f = exdir.File(filename, 'r')
    if filename_db is None:
        filename_db = filename
    try:
        state = MeasurementState()
        if not lazy:
            state.metadata.update(f.attrs)
        else:
            state.metadata = f.attrs
        for dataset_name in f.keys():
            parameters = [None for key in f[dataset_name]['parameters'].keys()]
            for parameter_id, parameter in f[dataset_name]['parameters'].items():
                if not lazy:
                    parameter_name = parameter.attrs['name']
                    parameter_setter = parameter.attrs['has_setter']
                    parameter_unit = parameter.attrs['unit']
                    parameter_values = parameter.data[:].copy()
                    parameters[int(parameter_id)] = MeasurementParameter(
                        parameter_values, parameter_setter,
                        parameter_name, parameter_unit)
                else:
                    parameters[int(parameter_id)] = LazyMeasParFromExdir(parameter)
            if not lazy:
                try:
                    data = f[dataset_name]['data'].data[:].copy()
                except Exception:
                    data = f[dataset_name]['data'].data
            else:
                data = f[dataset_name]['data'].data
            state.datasets[dataset_name] = MeasurementDataset(parameters, data)
        if db:
            # get db record and add info to the returned measurement state
            db_record = get(i for i in db.Data if (i.filename == filename_db))
            state.id = db_record.id
            state.start = db_record.start
            state.stop = db_record.stop
            state.measurement_type = db_record.measurement_type
            query = select(i for i in db.Reference if (i.this.id == state.id))
            references = {}
            for q in query:
                references.update({q.ref_type: q.that.id})
            state.references = references
        state.filename = filename
    finally:
        if not lazy:
            f.close()
        else:
            state.exdir = f
    return state
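
# Usage sketch: load a measurement eagerly (the file is closed on return) and
# inspect it; the path is illustrative, and the attributes follow the
# assignments made inside load_exdir above.
state = load_exdir('/data/measurements/qubit_spec.exdir')
print(state.metadata)
print(list(state.datasets))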
def exdir_tmpfile(tmpdir):
    testpath = pathlib.Path(tmpdir.strpath) / "test.exdir"
    f = exdir.File(testpath, mode="w")
    yield f
    f.close()
    remove(testpath)
def quantities_tmpfile(tmpdir):
    testpath = pathlib.Path(tmpdir.strpath) / "test.exdir"
    f = exdir.File(testpath, mode="w", plugins=exdir.plugins.quantities)
    yield f
    f.close()
    remove(testpath)
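
# Typical pytest usage of the fixtures above (they are assumed to be
# registered with @pytest.fixture, e.g. in a conftest.py):
def test_attrs_roundtrip(quantities_tmpfile):
    import quantities as pq
    f = quantities_tmpfile
    f.attrs['duration'] = 2.5 * pq.s
    assert f.attrs['duration'] == 2.5 * pq.s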