def generate_spike_trains(exdir_path, openephys_rec, source='klusta'):
    """Convert spike-sorting output to neo objects and write them via ExdirIO.

    Parameters
    ----------
    exdir_path : str
        Path of the exdir file to write into. For ``'klusta'`` and
        ``'kilosort'`` it is also used to locate the sorter output in the
        directory ``<acquisition dir>/<acquisition.attrs['openephys_session']>``.
    openephys_rec : object
        Open Ephys recording. Its ``channel_groups`` are read by the
        ``'openephys'`` and ``'kilosort'`` sources, together with
        ``duration`` (both) and ``sample_rate`` (``'kilosort'``).
        It is not used for ``'klusta'``.
    source : {'klusta', 'openephys', 'kilosort'}
        Which spike-sorting output to convert.

    Raises
    ------
    IOError
        If ``source == 'klusta'`` and no ``*_klusta.kwik`` file is found.
    ValueError
        If ``source`` is not one of the supported values.
    """
    import neo

    def make_channel_index(oe_group):
        # One neo.ChannelIndex per Open Ephys channel group.
        return neo.ChannelIndex(name='channel group {}'.format(oe_group.id),
                                channel_ids=[ch.id for ch in oe_group.channels],
                                index=[ch.index for ch in oe_group.channels],
                                group_id=oe_group.id)

    if source == 'klusta':  # TODO acquire features and masks
        print('Generating spike trains from KWIK file')
        exdir_file = exdir.File(exdir_path)
        acquisition = exdir_file["acquisition"]
        openephys_session = acquisition.attrs["openephys_session"]
        klusta_directory = op.join(str(acquisition.directory),
                                   openephys_session, 'klusta')
        n_kwik = 0
        for root, _, files in os.walk(klusta_directory):
            for fname in files:
                if not fname.endswith('_klusta.kwik'):
                    continue
                n_kwik += 1
                kwikfile = op.join(root, fname)
                kwikio = neo.io.KwikIO(filename=kwikfile)
                blk = kwikio.read_block(raw_data_units='uV')
                try:
                    exdirio = neo.io.ExdirIO(exdir_path)
                    exdirio.write_block(blk)
                except Exception as exc:
                    # Best effort: keep converting the remaining files, but
                    # report why this one failed instead of hiding the cause.
                    print('WARNING: unable to convert\n', kwikfile,
                          '({!r})'.format(exc))
        if n_kwik == 0:
            raise IOError('.kwik file cannot be found in ' + klusta_directory)
    elif source == 'openephys':
        exdirio = neo.io.ExdirIO(exdir_path)
        for oe_group in openephys_rec.channel_groups:
            chx = make_channel_index(oe_group)
            for sptr in oe_group.spiketrains:
                unit = neo.Unit(cluster_group='unsorted',
                                cluster_id=sptr.attrs['cluster_id'],
                                name=sptr.attrs['name'])
                unit.spiketrains.append(
                    neo.SpikeTrain(times=sptr.times,
                                   waveforms=sptr.waveforms,
                                   sampling_rate=sptr.sample_rate,
                                   t_stop=sptr.t_stop,
                                   **sptr.attrs))
                chx.units.append(unit)
            exdirio.write_channelindex(chx,
                                       start_time=0 * pq.s,
                                       stop_time=openephys_rec.duration)
    elif source == 'kilosort':
        print('Generating spike trains from KiloSort')
        exdirio = neo.io.ExdirIO(exdir_path)
        exdir_file = exdir.File(exdir_path)
        openephys_directory = op.join(
            str(exdir_file["acquisition"].directory),
            exdir_file["acquisition"].attrs["openephys_session"])
        cluster_dtype = [('cluster_id', 'i4'), ('group', 'U8')]
        # KiloSort output carries no channel <-> unit association yet
        # (TO BE IMPLEMENTED), so every unit is written to the first channel
        # group only (note the ``break`` at the end of the loop body).
        for oe_group in openephys_rec.channel_groups:
            chx = make_channel_index(oe_group)
            # load output
            spt = np.load(op.join(openephys_directory,
                                  'spike_times.npy')).flatten()
            spc = np.load(op.join(openephys_directory,
                                  'spike_clusters.npy')).flatten()
            try:
                # ndmin=1: a file with a single data row would otherwise come
                # back as a 0-d array, which cannot be iterated below.
                cgroup = np.loadtxt(op.join(openephys_directory,
                                            'cluster_group.tsv'),
                                    dtype=cluster_dtype,
                                    skiprows=1,
                                    ndmin=1)
                if cgroup.shape == (0, ):
                    raise FileNotFoundError
            except FileNotFoundError:
                # Manual corrections didn't happen: label every cluster that
                # occurs in spike_clusters.npy as 'unsorted'.
                cluster_ids = np.unique(spc)
                cgroup = np.array(
                    list(zip(cluster_ids, ['unsorted'] * cluster_ids.size)),
                    dtype=cluster_dtype)
            for cluster_id, grp in cgroup:
                unit = neo.Unit(cluster_group=str(grp), cluster_id=cluster_id,
                                name=cluster_id)
                unit.spiketrains.append(
                    neo.SpikeTrain(
                        times=(spt[spc == cluster_id].astype(float) /
                               openephys_rec.sample_rate).simplified,
                        t_stop=openephys_rec.duration,
                    ))
                chx.units.append(unit)
            exdirio.write_channelindex(chx,
                                       start_time=0 * pq.s,
                                       stop_time=openephys_rec.duration)
            break
    else:
        raise ValueError(source + ' not supported')
# Example #2
def get_duration(data_path):
    """Return the ``session_duration`` attribute of an exdir file, in seconds.

    The file at ``data_path`` is opened read-only with the exdir quantities
    plugin, and the stored value is rescaled to seconds.
    """
    exdir_root = exdir.File(str(data_path), 'r', plugins=[exdir.plugins.quantities])
    session_duration = exdir_root.attrs['session_duration']
    return session_duration.rescale('s')
# Example #3
    def write_sorting(sorting, save_path, recording=None, sample_rate=None, save_waveforms=False, verbose=False):
        """Write the units of a sorting extractor into an exdir file.

        Data go under ``processing/electrophysiology/channel_group_<g>``:
        ``UnitTimes/<unit_id>/times`` per unit, a ``Clustering`` group and,
        optionally, ``EventWaveform/waveform_timeseries``. Any existing
        ``UnitTimes``/``EventWaveform``/``Clustering`` sub-groups are removed
        first.

        Parameters
        ----------
        sorting : SortingExtractor
            Source of unit ids, spike trains (in frames), unit properties and
            spike features.
        save_path : str
            Path of the exdir file to write.
        recording : RecordingExtractor, optional
            If given, its sampling frequency is used instead of
            ``sample_rate``. Its channels and duration are also stored as
            attributes.
        sample_rate : float, optional
            Sampling rate in Hz. Required when ``recording`` is None.
        save_waveforms : bool
            Write ``EventWaveform`` when every unit has a ``'waveforms'``
            spike feature.
        verbose : bool
            Print progress messages.

        Raises
        ------
        Exception
            If neither ``recording`` nor ``sample_rate`` is provided.

        NOTE(review): with a single channel group, ``Clustering`` is written
        only when ``save_waveforms`` is True. With several groups it is always
        written, and its ``cluster_nums`` lists *all* unit ids rather than
        those of the group. Confirm which behaviour is intended.
        """
        assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
        if sample_rate is None and recording is None:
            raise Exception("Provide 'sample_rate' argument (Hz)")
        if recording is None:
            sample_rate = sample_rate * pq.Hz
        else:
            sample_rate = recording.get_sampling_frequency() * pq.Hz

        def _clear_sorting_data(group):
            # Delete each sub-group independently. A missing one (e.g. no
            # EventWaveform from an earlier run) must not prevent the others
            # from being removed, otherwise stale datasets are reused below.
            for key in ('UnitTimes', 'EventWaveform', 'Clustering'):
                if key in group:
                    del group[key]

        def _spike_times(unit):
            # Spike train of ``unit`` (frame indices) converted to seconds.
            return (sorting.get_unit_spike_train(unit).astype(float) / sample_rate).rescale('s')

        exdir_group = exdir.File(save_path, plugins=exdir.plugins.quantities)
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        ephys.attrs['sample_rate'] = sample_rate

        if 'group' in sorting.get_shared_unit_property_names():
            channel_groups = np.unique([sorting.get_unit_property(unit, 'group') for unit in sorting.get_unit_ids()])
        else:
            channel_groups = [0]

        # Latest spike over all units; the stored stop_time is never earlier than this.
        unit_stop_time = np.max([np.max(_spike_times(u)) for u in sorting.get_unit_ids()]) * pq.s

        if len(channel_groups) == 1:
            chan = 0
            if verbose:
                print("Single group: ", chan)
            ch_group = ephys.require_group('channel_group_' + str(chan))
            _clear_sorting_data(ch_group)
            unittimes = ch_group.require_group('UnitTimes')
            recording_stop_time = None
            if recording is not None:
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array([])
                ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
                ch_group.attrs['start_time'] = 0 * pq.s
                recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s

                unittimes.attrs['electrode_group_id'] = chan
                unittimes.attrs['electrode_identities'] = np.array([])
                unittimes.attrs['electrode_idx'] = np.array(recording.get_channel_ids())
                unittimes.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['sample_rate'] = sample_rate

            stop_time = None
            if recording_stop_time is not None:
                stop_time = recording_stop_time if recording_stop_time > unit_stop_time else unit_stop_time
                unittimes.attrs['stop_time'] = stop_time
                ch_group.attrs['stop_time'] = stop_time

            nums = np.array([])
            timestamps = np.array([])
            waveforms = np.array([])
            for unit in sorting.get_unit_ids():
                unit_times = _spike_times(unit)
                unit_group = unittimes.require_group(str(unit))
                unit_group.require_dataset('times', data=unit_times)
                unit_group.attrs['cluster_group'] = 'unsorted'
                unit_group.attrs['group_id'] = chan
                unit_group.attrs['name'] = 'unit #' + str(unit)

                timestamps = np.concatenate((timestamps, unit_times))
                nums = np.concatenate((nums, [unit] * len(unit_times)))

                if 'waveforms' in sorting.get_unit_spike_feature_names(unit):
                    if len(waveforms) == 0:
                        waveforms = sorting.get_unit_spike_features(unit, 'waveforms')
                    else:
                        waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms')))

            if save_waveforms:
                if verbose:
                    print("Saving EventWaveforms")
                if 'waveforms' in sorting.get_shared_unit_spike_feature_names():
                    eventwaveform = ch_group.require_group('EventWaveform')
                    waveform_ts = eventwaveform.require_group('waveform_timeseries')
                    data = waveform_ts.require_dataset('data', data=waveforms)
                    waveform_ts.attrs['electrode_group_id'] = chan
                    data.attrs['num_samples'] = len(waveforms)
                    data.attrs['sample_rate'] = sample_rate
                    data.attrs['unit'] = pq.dimensionless
                    times = waveform_ts.require_dataset('timestamps', data=timestamps)
                    times.attrs['num_samples'] = len(timestamps)
                    times.attrs['unit'] = pq.s
                    if recording is not None:
                        waveform_ts.attrs['electrode_identities'] = np.array([])
                        waveform_ts.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))
                        waveform_ts.attrs['start_time'] = 0 * pq.s
                        if stop_time is not None:
                            waveform_ts.attrs['stop_time'] = stop_time
                        waveform_ts.attrs['sample_rate'] = sample_rate
                        waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                        waveform_ts.attrs['num_samples'] = len(waveforms)
                if verbose:
                    print("Saving Clustering")
                clustering = ch_group.require_group('Clustering')
                ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
                ts.attrs['num_samples'] = len(timestamps)
                ts.attrs['unit'] = pq.s
                ns = clustering.require_dataset('nums', data=nums)
                ns.attrs['num_samples'] = len(nums)
                cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids()))
                cn.attrs['num_samples'] = len(sorting.get_unit_ids())
        else:
            # Remove pre-existing spike sorting data from every channel group
            # present in the file, not just a fixed range of group ids.
            prefix = 'channel_group_'
            for name in list(ephys.keys()):
                if name.startswith(prefix):
                    if verbose:
                        print('Removing channel', name[len(prefix):], 'info')
                    _clear_sorting_data(ephys[name])

            for chan in channel_groups:
                if verbose:
                    print("Group: ", chan)
                ch_group = ephys.require_group('channel_group_' + str(chan))
                unittimes = ch_group.require_group('UnitTimes')
                recording_stop_time = None
                group_channel_ids = None
                if recording is not None:
                    # (index, id) of the recording channels belonging to this group.
                    group_channels = [(i_c, ch) for i_c, ch in enumerate(recording.get_channel_ids())
                                      if recording.get_channel_property(ch, 'group') == chan]
                    group_channel_ids = np.array([ch for _, ch in group_channels])
                    group_channel_idx = np.array([i_c for i_c, _ in group_channels])

                    # NOTE(review): UnitTimes/EventWaveform store channel *ids* as
                    # 'electrode_idx', while the channel group stores channel
                    # *indices* under both 'electrode_identities' and
                    # 'electrode_idx'. Kept as-is; confirm the intended schema.
                    unittimes.attrs['electrode_group_id'] = chan
                    unittimes.attrs['electrode_identities'] = np.array([])
                    unittimes.attrs['electrode_idx'] = group_channel_ids
                    unittimes.attrs['start_time'] = 0 * pq.s
                    recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s

                    ch_group.attrs['electrode_group_id'] = chan
                    ch_group.attrs['electrode_identities'] = group_channel_idx
                    ch_group.attrs['electrode_idx'] = group_channel_idx
                    ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['sample_rate'] = sample_rate

                stop_time = None
                if recording_stop_time is not None:
                    stop_time = recording_stop_time if recording_stop_time > unit_stop_time else unit_stop_time
                    unittimes.attrs['stop_time'] = stop_time
                    ch_group.attrs['stop_time'] = stop_time

                nums = np.array([])
                timestamps = np.array([])
                waveforms = np.array([])
                for unit in sorting.get_unit_ids():
                    if sorting.get_unit_property(unit, 'group') != chan:
                        continue
                    if verbose:
                        print("Unit: ", unit)
                    unit_times = _spike_times(unit)
                    unit_group = unittimes.require_group(str(unit))
                    unit_group.require_dataset('times', data=unit_times)
                    unit_group.attrs['cluster_group'] = 'unsorted'
                    unit_group.attrs['group_id'] = chan
                    unit_group.attrs['name'] = 'unit #' + str(unit)

                    timestamps = np.concatenate((timestamps, unit_times))
                    nums = np.concatenate((nums, [unit] * len(unit_times)))

                    if 'waveforms' in sorting.get_unit_spike_feature_names(unit):
                        if len(waveforms) == 0:
                            waveforms = sorting.get_unit_spike_features(unit, 'waveforms')
                        else:
                            waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms')))

                if save_waveforms:
                    if verbose:
                        print("Saving EventWaveforms")
                    if 'waveforms' in sorting.get_shared_unit_spike_feature_names():
                        eventwaveform = ch_group.require_group('EventWaveform')
                        waveform_ts = eventwaveform.require_group('waveform_timeseries')
                        data = waveform_ts.require_dataset('data', data=waveforms)
                        data.attrs['num_samples'] = len(waveforms)
                        data.attrs['sample_rate'] = sample_rate
                        data.attrs['unit'] = pq.dimensionless
                        times = waveform_ts.require_dataset('timestamps', data=timestamps)
                        times.attrs['num_samples'] = len(timestamps)
                        times.attrs['unit'] = pq.s
                        waveform_ts.attrs['electrode_group_id'] = chan
                        if recording is not None:
                            waveform_ts.attrs['electrode_identities'] = np.array([])
                            waveform_ts.attrs['electrode_idx'] = group_channel_ids
                            waveform_ts.attrs['start_time'] = 0 * pq.s
                            if stop_time is not None:
                                waveform_ts.attrs['stop_time'] = stop_time
                            waveform_ts.attrs['sample_rate'] = sample_rate
                            waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                            waveform_ts.attrs['num_samples'] = len(waveforms)
                if verbose:
                    print("Saving Clustering")
                clustering = ch_group.require_group('Clustering')
                ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
                ts.attrs['num_samples'] = len(timestamps)
                ts.attrs['unit'] = pq.s
                ns = clustering.require_dataset('nums', data=nums)
                ns.attrs['num_samples'] = len(nums)
                cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids()))
                cn.attrs['num_samples'] = len(sorting.get_unit_ids())
# Example #4
    def write_recording(recording, save_path, lfp=False, mua=False):
        """Write a recording extractor's traces into an exdir file.

        With neither flag set, the full trace matrix is stored as the dataset
        ``acquisition/timeseries``. Otherwise each channel is stored separately
        under
        ``processing/electrophysiology/channel_group_<g>/<KIND>/<KIND>_timeseries_<ch>/data``,
        where ``<KIND>`` is ``'LFP'`` when ``lfp`` is set and ``'MUA'``
        otherwise (``lfp`` wins if both flags are set).

        If the recording has at most one channel group, everything is written
        to ``channel_group_0``.

        Parameters
        ----------
        recording : RecordingExtractor
            Source of traces, channel ids/properties, frame count and
            sampling frequency.
        save_path : str
            Path of the exdir file to write.
        lfp : bool
            Store per-channel LFP time series.
        mua : bool
            Store per-channel MUA time series (ignored if ``lfp`` is True).
        """
        assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
        channel_ids = recording.get_channel_ids()
        exdir_group = exdir.File(save_path, plugins=[exdir.plugins.quantities])
        sample_rate = recording.get_sampling_frequency() * pq.Hz

        if not lfp and not mua:
            # Only this branch needs the whole trace matrix, so only load it here.
            acq = exdir_group.require_group('acquisition')
            timeseries = acq.require_dataset('timeseries', data=recording.get_traces())
            timeseries.attrs['sample_rate'] = sample_rate
            timeseries.attrs['electrode_identities'] = np.array(channel_ids)
            return

        kind = 'LFP' if lfp else 'MUA'
        num_frames = recording.get_num_frames()
        stop_time = num_frames / float(recording.get_sampling_frequency()) * pq.s

        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        ephys.attrs['sample_rate'] = sample_rate
        if 'group' in recording.get_shared_channel_property_names():
            channel_groups = np.unique([recording.get_channel_property(ch, 'group')
                                        for ch in channel_ids])
        else:
            channel_groups = [0]

        # (group label, [(channel index, channel id), ...]) pairs to write.
        if len(channel_groups) == 1:
            # A single group is always stored as channel_group_0, whatever its label.
            groups = [(0, list(enumerate(channel_ids)))]
        else:
            groups = [(chan, [(i_c, ch) for i_c, ch in enumerate(channel_ids)
                              if recording.get_channel_property(ch, 'group') == chan])
                      for chan in channel_groups]

        for chan, members in groups:
            ch_group = ephys.require_group('channel_group_' + str(chan))
            kind_group = ch_group.require_group(kind)
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = np.array([ch for _, ch in members])
            ch_group.attrs['electrode_idx'] = np.array([i_c for i_c, _ in members])
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = stop_time
            for i_c, ch in members:
                ts_group = kind_group.require_group(kind + '_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = num_frames
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = stop_time
                ts_group.attrs['sample_rate'] = sample_rate
                data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch]))
                data.attrs['sample_rate'] = sample_rate
                data.attrs['unit'] = pq.uV
# Example #5
 def _add_to_Exdir(self, name, target_dir, level_dir):
     """Require group ``level_dir`` and raw entry ``name`` in ``experiment.exdir``.

     The exdir path ``"experiment.exdir"`` is relative, so it resolves
     against the current working directory. The raw entry is required at
     the top level of the file, not inside ``level_dir``.

     NOTE(review): ``target_dir`` is unused, and the result of the
     ``level_dir`` group request is discarded. Confirm whether
     ``require_raw`` was meant to be called on that group instead.
     """
     exdir_root = exdir.File("experiment.exdir")
     exdir_root.require_group(level_dir)
     exdir_root.require_raw(name)
# Example #6
def experiment_plot(project_path,
                    action_id,
                    n_channel=8,
                    rem_channel="all",
                    skip_channels=None,
                    raster_start=-0.5,
                    raster_stop=1):
    """
    Plot orientation rasters and tuning overviews from visual data acquired
    through open-ephys and psychopy.

    Parameters
    ----------
    project_path: os.path, or equivalent
        Path to project directory
    action_id: string
        The experiment/action id
    n_channel: default 8; int
        Number of channel groups in recordings
    rem_channel: default "all"; "all" or int
        Channel group(s) of interest: "all", or a single integer in
        ``range(n_channel)``
    skip_channels: default None; int or (list, array, tuple)
        ID(s) of channel groups to skip; only used when ``rem_channel`` is
        "all". IDs outside ``range(n_channel)`` are ignored.
    raster_start: default -0.5; int or float
        When the rasters start on the x-axis relative to trial start, which
        is 0 (in seconds)
    raster_stop: default 1; int or float
        When the rasters stop on the x-axis relative to trial start, which
        is 0 (in seconds)

    Returns
    -------
    None; saves figures into the exdir directory: main.exdir/figures/channel_<n>

    Raises
    ------
    AttributeError
        If ``rem_channel`` is neither "all" nor an integer in
        ``range(n_channel)``.
    KeyError
        If a spiketrain has no "cluster_group" annotation.
    """
    valid_int_channel = (isinstance(rem_channel, int) and
                         0 <= rem_channel < n_channel)
    if not (rem_channel == "all" or valid_int_channel):
        msg = "rem_channel must be either 'all' or integer between 0 and n_channel ({}); not {}".format(
            n_channel, rem_channel)
        raise AttributeError(msg)

    # Define project tree
    project = expipe.get_project(project_path)
    action = project.actions[action_id]
    data_path = er.get_data_path(action)
    epochs = er.load_epochs(data_path)

    # Get data of interest (orients vs rates vs channel)
    oe_epoch = epochs[0]  # openephys
    assert (oe_epoch.annotations['provenance'] == 'open-ephys')
    ps_epoch = epochs[1]  # psychopy
    assert (ps_epoch.annotations['provenance'] == 'psychopy')

    # Create directory for figures
    exdir_file = exdir.File(data_path, plugins=exdir.plugins.quantities)
    figures_group = exdir_file.require_group('figures')

    raster_start = raster_start * pq.s
    raster_stop = raster_stop * pq.s
    orients = ps_epoch.labels  # the labels are orientations (135, 90, ...)

    def plot(channel_num, channel_path, spiketrains):
        # Create figures from spiketrains
        for spiketrain in spiketrains:
            try:
                if spiketrain.annotations["cluster_group"] == "noise":
                    continue
            except KeyError:
                msg = "Cluster/channel group {} seems to not have been sorted in phys".format(
                    channel_num)
                raise KeyError(msg)

            figure_id = "{}_{}_".format(channel_num,
                                        spiketrain.annotations['cluster_id'])

            sns.set()
            sns.set_style("white")
            # Raster plot processing
            trials = er.make_spiketrain_trials(spiketrain,
                                               oe_epoch,
                                               t_start=raster_start,
                                               t_stop=raster_stop)
            er.add_orientation_to_trials(trials, orients)
            orf_path = os.path.join(channel_path,
                                    figure_id + "orrient_raster.png")
            orient_raster_fig = orient_raster_plots(trials)
            orient_raster_fig.savefig(orf_path)

            # Orientation vs spike frequency plot (tuning curves) processing
            trials = er.make_spiketrain_trials(spiketrain, oe_epoch)
            er.add_orientation_to_trials(trials, orients)
            tf_path = os.path.join(channel_path, figure_id + "tuning.png")
            tuning_fig = plot_tuning_overview(trials, spiketrain)
            tuning_fig.savefig(tf_path)

            # Reset before next loop to save memory
            plt.close(fig="all")

    def plot_channel(channel):
        # require_group makes sure figures/channel_<n> exists before saving.
        channel_name = "channel_{}".format(channel)
        figures_group.require_group(channel_name)
        # Join path components separately: a literal "figures\\" prefix only
        # forms a directory path on Windows.
        channel_path = os.path.join(str(data_path), "figures", channel_name)
        spiketrains = er.load_spiketrains(str(data_path), channel)
        plot(channel, channel_path, spiketrains)

    if rem_channel == "all":
        # Skip by channel ID, not by list position: deleting positions from a
        # shrinking list removed the wrong channels after the first deletion.
        if isinstance(skip_channels, int):
            skip = {skip_channels}
        elif isinstance(skip_channels, (list, tuple, type(empty(0)))):
            # presumably ``empty`` is numpy.empty, i.e. this accepts ndarrays
            # -- TODO confirm against the module imports
            skip = set(skip_channels)
        else:
            skip = set()
        for channel in range(n_channel):
            if channel not in skip:
                plot_channel(channel)
    else:
        # Validation above guarantees an int in range(n_channel) here.
        plot_channel(rem_channel)
    def writeSorting(sorting, exdir_file, recording=None, sample_rate=None):
        """Write the units of ``sorting`` into the exdir file ``exdir_file``.

        For each channel group ``<g>`` under ``processing/electrophysiology``:
        spike times go to ``Channel_group_<g>/UnitTimes/<unit>/times``; if
        units carry a ``'waveforms'`` spike feature, the stacked waveforms and
        their timestamps go to ``EventWaveform/waveform_timeseries``; and
        ``Clustering`` stores spike timestamps, per-spike unit ids (``nums``)
        and the group's unit ids (``cluster_nums``).

        Parameters
        ----------
        sorting : SortingExtractor
            Sorting to write.
        exdir_file : str or path
            Path of the exdir file to write into.
        recording : RecordingExtractor, optional
            If given, its sampling frequency is used (overriding
            ``sample_rate``), its ``'group'`` channel property defines the
            channel groups, and electrode/timing attributes are written.
        sample_rate : float, optional
            Sampling frequency in Hz; required when ``recording`` is None.

        Raises
        ------
        Exception
            If neither ``recording`` nor ``sample_rate`` is given.
        """
        exdir, pq = _load_required_modules()

        if sample_rate is None and recording is None:
            raise Exception("Provide 'sample_rate' argument (Hz)")
        if recording is None:
            sample_rate = sample_rate * pq.Hz
        else:
            sample_rate = recording.getSamplingFrequency() * pq.Hz

        exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')

        def stop_time():
            # Recording duration in seconds (only valid when recording is given).
            return recording.getNumFrames() / float(recording.getSamplingFrequency()) * pq.s

        def write_units(chan, unit_ids, unittimes, eventwaveform,
                        waveform_electrode_idx, announce_units):
            """Write UnitTimes, EventWaveform and Clustering data of ``unit_ids`` for group ``chan``."""
            nums = np.array([])
            timestamps = np.array([])
            waveforms = np.array([])
            for unit in unit_ids:
                if announce_units:
                    print("Unit: ", unit)
                spike_train = sorting.getUnitSpikeTrain(unit)
                spike_times = (spike_train.astype(float) / sample_rate).rescale('s')

                unit_group = unittimes.require_group(str(unit))
                unit_group.require_dataset('times', data=spike_times)
                unit_group.attrs['cluster_group'] = 'unsorted'
                unit_group.attrs['group_id'] = chan
                unit_group.attrs['name'] = 'unit #' + str(unit)

                timestamps = np.concatenate((timestamps, spike_times))
                nums = np.concatenate((nums, [unit] * len(spike_train)))

                # Only this group's units contribute waveforms, so the rows
                # stay aligned with ``timestamps`` and ``nums``.
                if 'waveforms' in sorting.getUnitSpikeFeatureNames(unit):
                    unit_waveforms = sorting.getUnitSpikeFeatures(unit, 'waveforms')
                    if len(waveforms) == 0:
                        waveforms = unit_waveforms
                    else:
                        waveforms = np.vstack((waveforms, unit_waveforms))

            print("Saving eventwaveforms and clustering")
            # Some group may hold no unit with waveforms even though others do;
            # skip it rather than fail on ``waveforms.shape[1]``.
            if 'waveforms' in sorting.getUnitSpikeFeatureNames() and len(waveforms) > 0:
                waveform_ts = eventwaveform.require_group('waveform_timeseries')
                data = waveform_ts.require_dataset('data', data=waveforms)
                data.attrs['num_samples'] = len(waveforms)
                data.attrs['sample_rate'] = sample_rate
                data.attrs['unit'] = pq.dimensionless
                times = waveform_ts.require_dataset('timestamps', data=timestamps)
                times.attrs['num_samples'] = len(timestamps)
                times.attrs['unit'] = pq.s
                waveform_ts.attrs['electrode_group_id'] = chan
                if recording is not None:
                    waveform_ts.attrs['electrode_identities'] = np.array([])
                    waveform_ts.attrs['electrode_idx'] = waveform_electrode_idx
                    waveform_ts.attrs['start_time'] = 0 * pq.s
                    waveform_ts.attrs['stop_time'] = stop_time()
                    waveform_ts.attrs['sample_rate'] = sample_rate
                    waveform_ts.attrs['sample_length'] = waveforms.shape[1]
                    waveform_ts.attrs['num_samples'] = len(waveforms)
            clustering = ephys.require_group('Channel_group_' + str(chan)).require_group('Clustering')
            ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)
            ts.attrs['num_samples'] = len(timestamps)
            ts.attrs['unit'] = pq.s
            ns = clustering.require_dataset('nums', data=nums)
            ns.attrs['num_samples'] = len(nums)
            cn = clustering.require_dataset('cluster_nums', data=np.array(unit_ids))
            cn.attrs['num_samples'] = len(unit_ids)

        # ``recording`` is optional: only consult its channel properties when
        # it is present, otherwise everything goes to a single group.
        if recording is not None and 'group' in recording.getChannelPropertyNames():
            channel_groups = np.unique([sorting.getUnitProperty(unit, 'group')
                                        for unit in sorting.getUnitIds()])
        else:
            channel_groups = [0]

        if len(channel_groups) == 1:
            # NOTE(review): a single group is always written as group 0, even
            # if its 'group' property has another value -- confirm intended.
            chan = 0
            print("Single group: ", chan)
            ch_group = ephys.require_group('Channel_group_' + str(chan))
            unittimes = ch_group.require_group('UnitTimes')
            eventwaveform = ch_group.require_group('EventWaveform')
            waveform_electrode_idx = None
            if recording is not None:
                ch_group.attrs['electrode_group_id'] = chan
                ch_group.attrs['electrode_identities'] = np.array([])
                ch_group.attrs['electrode_idx'] = np.arange(len(recording.getChannelIds()))
                ch_group.attrs['start_time'] = 0 * pq.s
                ch_group.attrs['stop_time'] = stop_time()
                unittimes.attrs['electrode_group_id'] = chan
                unittimes.attrs['electrode_identities'] = np.array([])
                unittimes.attrs['electrode_idx'] = np.array(recording.getChannelIds())
                unittimes.attrs['start_time'] = 0 * pq.s
                unittimes.attrs['stop_time'] = stop_time()
                waveform_electrode_idx = np.arange(len(recording.getChannelIds()))
            write_units(chan, sorting.getUnitIds(), unittimes, eventwaveform,
                        waveform_electrode_idx, announce_units=False)
        else:
            # Only reachable with a recording that has a 'group' property.
            for chan in channel_groups:
                print("Group: ", chan)
                ch_group = ephys.require_group('Channel_group_' + str(chan))
                unittimes = ch_group.require_group('UnitTimes')
                eventwaveform = ch_group.require_group('EventWaveform')
                waveform_electrode_idx = None
                if recording is not None:
                    channel_ids = recording.getChannelIds()
                    group_channel_ids = np.array([ch for ch in channel_ids
                                                  if recording.getChannelProperty(ch, 'group') == chan])
                    group_channel_idx = np.array([i_c for i_c, ch in enumerate(channel_ids)
                                                  if recording.getChannelProperty(ch, 'group') == chan])
                    unittimes.attrs['electrode_group_id'] = chan
                    unittimes.attrs['electrode_identities'] = np.array([])
                    unittimes.attrs['electrode_idx'] = group_channel_ids
                    unittimes.attrs['start_time'] = 0 * pq.s
                    unittimes.attrs['stop_time'] = stop_time()
                    ch_group.attrs['electrode_group_id'] = chan
                    ch_group.attrs['electrode_identities'] = group_channel_idx
                    ch_group.attrs['electrode_idx'] = group_channel_idx
                    ch_group.attrs['start_time'] = 0 * pq.s
                    ch_group.attrs['stop_time'] = stop_time()
                    waveform_electrode_idx = group_channel_ids
                group_unit_ids = [unit for unit in sorting.getUnitIds()
                                  if sorting.getUnitProperty(unit, 'group') == chan]
                write_units(chan, group_unit_ids, unittimes, eventwaveform,
                            waveform_electrode_idx, announce_units=True)
    def writeRecording(recording, exdir_file, lfp=False, mua=False):
        """Write the traces of ``recording`` into the exdir file ``exdir_file``.

        Parameters
        ----------
        recording : RecordingExtractor
            Recording whose traces are written.
        exdir_file : str or path
            Path of the exdir file to write into.
        lfp : bool
            Write one dataset per channel under
            ``processing/electrophysiology/Channel_group_<g>/LFP/LFP_timeseries_<ch>/data``.
        mua : bool
            Same layout as ``lfp`` but under ``MUA`` / ``MUA_timeseries_<ch>``;
            ignored when ``lfp`` is True.

        When neither flag is set, the full trace matrix is written to the
        ``acquisition/timeseries`` dataset. Channel groups come from the
        ``'group'`` channel property; without it, or with a single group, all
        channels are written to ``Channel_group_0``.
        """
        exdir, pq = _load_required_modules()

        exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)

        if not lfp and not mua:
            # Only this branch needs every channel at once; reading the traces
            # here instead of up front avoids loading them for LFP/MUA output.
            raw = recording.getTraces()
            timeseries = exdir_group.require_group('acquisition').require_dataset('timeseries', data=raw)
            timeseries.attrs['sample_rate'] = recording.getSamplingFrequency() * pq.Hz
            return

        # LFP and MUA share one layout; ``lfp`` wins when both flags are set.
        kind = 'LFP' if lfp else 'MUA'
        channel_ids = recording.getChannelIds()
        num_frames = recording.getNumFrames()
        sampling_frequency = recording.getSamplingFrequency()
        stop_time = num_frames / float(sampling_frequency) * pq.s

        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        if 'group' in recording.getChannelPropertyNames():
            channel_groups = np.unique([recording.getChannelProperty(ch, 'group')
                                        for ch in channel_ids])
        else:
            channel_groups = [0]

        def write_group(chan, members, electrode_idx):
            """Write group attrs and one ``<kind>_timeseries_<ch>`` per (index, id) in ``members``."""
            ch_group = ephys.require_group('Channel_group_' + str(chan))
            kind_group = ch_group.require_group(kind)
            ch_group.attrs['electrode_group_id'] = chan
            ch_group.attrs['electrode_identities'] = electrode_idx
            ch_group.attrs['electrode_idx'] = electrode_idx
            ch_group.attrs['start_time'] = 0 * pq.s
            ch_group.attrs['stop_time'] = stop_time
            for i_c, ch in members:
                ts_group = kind_group.require_group(kind + '_timeseries_' + str(ch))
                ts_group.attrs['electrode_group_id'] = chan
                ts_group.attrs['electrode_identity'] = ch
                ts_group.attrs['num_samples'] = num_frames
                ts_group.attrs['electrode_idx'] = i_c
                ts_group.attrs['start_time'] = 0 * pq.s
                ts_group.attrs['stop_time'] = stop_time
                ts_group.attrs['sample_rate'] = sampling_frequency * pq.Hz
                data = ts_group.require_dataset('data', data=recording.getTraces(channel_ids=[ch]))
                data.attrs['sample_rate'] = sampling_frequency * pq.Hz
                data.attrs['unit'] = pq.uV

        if len(channel_groups) == 1:
            write_group(0, list(enumerate(channel_ids)), np.arange(len(channel_ids)))
        else:
            for chan in channel_groups:
                members = [(i_c, ch) for i_c, ch in enumerate(channel_ids)
                           if recording.getChannelProperty(ch, 'group') == chan]
                write_group(chan, members, np.array([i_c for i_c, _ in members]))
Beispiel #9
0
    def __init__(self,
                 exdir_file,
                 sample_rate=None,
                 channel_group=None,
                 load_waveforms=False):
        """Load sorted units from ``processing/electrophysiology`` of an exdir file.

        Parameters
        ----------
        exdir_file : str or path
            Path of the exdir file to read.
        sample_rate : float or quantities.Quantity, optional
            Used when the electrophysiology group has no ``sample_rate``
            attribute; plain numbers are interpreted as Hz.
        channel_group : int, optional
            If given, only units of this channel group are loaded.
        load_waveforms : bool
            Also attach a ``'waveforms'`` spike feature per unit, read from
            the group's ``Clustering`` and ``EventWaveform`` subgroups when
            they are present.

        Raises
        ------
        Exception
            If no sample rate is available, or the ``electrophysiology`` group
            is missing.
        """
        assert HAVE_EXDIR, "To use the ExdirExtractors run:\n\n pip install exdir\n\n"
        SortingExtractor.__init__(self)
        self._exdir_file = exdir_file
        exdir_group = exdir.File(exdir_file, plugins=exdir.plugins.quantities)

        electrophysiology = None
        if 'processing' in exdir_group.keys():
            if 'electrophysiology' in exdir_group['processing']:
                electrophysiology = exdir_group['processing'][
                    'electrophysiology']
                ephys_attrs = electrophysiology.attrs
                if 'sample_rate' in ephys_attrs:
                    sample_rate = ephys_attrs['sample_rate']
        # Resolve the sample rate on every path: previously a missing or
        # unit-less rate crashed on ``.rescale`` when 'processing' existed.
        if sample_rate is None:
            raise Exception(
                "Sampling rate information not found. Please provide it wiht the 'sample_rate' "
                "argument")
        if not isinstance(sample_rate, pq.Quantity):
            sample_rate = sample_rate * pq.Hz
        self._sampling_frequency = float(sample_rate.rescale('Hz').magnitude)

        if electrophysiology is None:
            raise Exception("'electrophysiology' group not found!")

        self._unit_ids = []
        current_unit = 1
        self._spike_trains = []
        for chan_name, channel in electrophysiology.items():
            # Case-insensitive so groups named 'Channel_group_<n>' (as written
            # by writeSorting) are read as well as 'channel_group_<n>'.
            if 'channel' not in chan_name.lower():
                continue
            group = int(chan_name.split('_')[-1])
            if channel_group is not None and group != channel_group:
                continue

            # Reset per group so a group without clustering data never reuses
            # the previous group's spike-to-unit mapping.
            nums = None
            waveforms = None
            if load_waveforms and 'Clustering' in channel.keys() \
                    and 'EventWaveform' in channel.keys():
                clustering = channel.require_group('Clustering')
                eventwaveform = channel.require_group('EventWaveform')
                if 'waveform_timeseries' in eventwaveform:
                    nums = clustering['nums'].data
                    waveforms = eventwaveform['waveform_timeseries']['data'].data

            if 'UnitTimes' in channel.keys():
                for unit, unit_times in channel['UnitTimes'].items():
                    self._unit_ids.append(current_unit)
                    self._spike_trains.append(
                        (unit_times['times'].data.rescale('s') *
                         sample_rate).magnitude)
                    for k, v in unit_times.attrs.items():
                        self.set_unit_property(current_unit, k, v)
                    if waveforms is not None:
                        unit_idxs = np.where(nums == int(unit))
                        self.set_unit_spike_features(
                            current_unit, 'waveforms', waveforms[unit_idxs])
                    current_unit += 1
Beispiel #10
0
 def setup():
     """Create an empty ``test.exdir`` in a fresh temp dir; return ``((file,), {})``."""
     target = str(tmpdir.mkdir("test").join("test.exdir"))
     # Start from a clean slate if something already lives at the path.
     if os.path.exists(target):
         shutil.rmtree(target)
     exdir_file = exdir.File(target)
     args = (exdir_file,)
     kwargs = {}
     return args, kwargs
Beispiel #11
0
# This is a usecase that shows how h5py can be swapped with exdir.
# usecase_h5py.py shows the same usecase with h5py instead

import exdir
import numpy as np

time = np.linspace(0, 100, 101)

voltage_1 = np.sin(time)
voltage_2 = np.sin(time) + 10


f = exdir.File("experiments.exdir", "w")
f.attrs['description'] = "This is a mock experiment with voltage values over time"


# Creating group and datasets for experiment 1
grp_1 = f.create_group("experiment_1")

dset_time_1 = grp_1.create_dataset("time", data=time)
dset_time_1.attrs['unit'] = "ms"

dset_voltage_1 = grp_1.create_dataset("voltage", data=voltage_1)
dset_voltage_1.attrs['unit'] = "mV"


# Creating group and datasets for experiment 2
grp_2 = f.create_group("experiment_2")

dset_time_2 = grp_2.create_dataset("time", data=time)
dset_time_2.attrs['unit'] = "ms"

# voltage_2 was computed above but never stored; mirror experiment 1.
dset_voltage_2 = grp_2.create_dataset("voltage", data=voltage_2)
dset_voltage_2.attrs['unit'] = "mV"

# Release the file once everything is written.
f.close()
Beispiel #12
0
 def _openResources(self):
     """Open ``self._fileName`` read-only and store it as ``self._exdirGroup``."""
     file_name = self._fileName
     logger.info("Opening: {}".format(file_name))
     root = exdir.File(file_name, mode='r')
     self._exdirGroup = root
def load_exdir(filename, db=None, filename_db=None, lazy=False):
    """
    Loads measurement state from ExDir file system and database if the latter is provided.

    Parameters
    ----------
    filename : str
        Absolute path to the exdir file.
    db : MyDatabase
        Bound pony database instance. If given, the ``db.Data`` record whose
        ``filename`` equals ``filename_db`` supplies id, start, stop,
        measurement type and references.
    filename_db : str, optional
        Filename used to look up the database record; defaults to
        ``filename``.
    lazy : bool
        If True, function leaves ExDir file open and sets
        retval.exdir to this file; metadata, parameters and data are then
        backed by the open file instead of being copied into memory.

    Returns
    -------
    MeasurementState : retval
        Measurement state that is obtained from combining data from ExDir by
        filename and, if ``db`` is given, from the database record with the
        same filename.

    Raises
    ------
    Any exception raised while reading is propagated after the ExDir file has
    been closed (also in lazy mode).
    """
    f = exdir.File(filename, 'r')

    if filename_db is None:
        filename_db = filename

    try:
        state = MeasurementState()
        if not lazy:
            state.metadata.update(f.attrs)
        else:
            state.metadata = f.attrs

        for dataset_name in f.keys():
            dataset = f[dataset_name]
            parameter_group = dataset['parameters']
            # Parameter group keys are positional indices ("0", "1", ...).
            parameters = [None for _ in parameter_group.keys()]
            for parameter_id, parameter in parameter_group.items():
                if not lazy:
                    parameter_name = parameter.attrs['name']
                    parameter_setter = parameter.attrs['has_setter']
                    parameter_unit = parameter.attrs['unit']
                    parameter_values = parameter.data[:].copy()
                    parameters[int(parameter_id)] = MeasurementParameter(
                        parameter_values, parameter_setter, parameter_name,
                        parameter_unit)
                else:
                    parameters[int(parameter_id)] = LazyMeasParFromExdir(
                        parameter)

            if not lazy:
                try:
                    data = dataset['data'].data[:].copy()
                except Exception:
                    # presumably for data objects that cannot be sliced or
                    # copied -- keep the object as stored (TODO confirm)
                    data = dataset['data'].data
            else:
                data = dataset['data'].data
            state.datasets[dataset_name] = MeasurementDataset(parameters, data)

        if db:
            # get db record and add info to the returned measurement state
            db_record = get(i for i in db.Data if (i.filename == filename_db))
            state.id = db_record.id
            state.start = db_record.start
            state.stop = db_record.stop
            state.measurement_type = db_record.measurement_type
            query = select(i for i in db.Reference if (i.this.id == state.id))
            references = {}
            for q in query:
                references.update({q.ref_type: q.that.id})
            state.references = references
            state.filename = filename
    except BaseException:
        # Never leak the handle on failure, lazy or not.
        f.close()
        raise

    if lazy:
        state.exdir = f
    else:
        f.close()
    return state
Beispiel #14
0
def exdir_tmpfile(tmpdir):
    """Yield a writable exdir File in *tmpdir*; close and remove it afterwards."""
    location = pathlib.Path(tmpdir.strpath).joinpath("test.exdir")
    handle = exdir.File(location, mode="w")
    yield handle
    # Teardown: release the file, then delete it from disk.
    handle.close()
    remove(location)
Beispiel #15
0
def quantities_tmpfile(tmpdir):
    """Yield a writable exdir File using the quantities plugin; clean up afterwards."""
    location = pathlib.Path(tmpdir.strpath).joinpath("test.exdir")
    handle = exdir.File(location, mode="w", plugins=exdir.plugins.quantities)
    yield handle
    # Teardown: release the file, then delete it from disk.
    handle.close()
    remove(location)