Example #1
    def read_block(self, lazy=False, cascade=True):
        """
        Read the file and return a Block containing a single Segment with the
        analog signals, tracking signals and spike trains (when cascade is True).
        """

        blk = Block()
        if cascade:
            seg = Segment(file_origin=self._absolute_filename)

            blk.channel_indexes = self._channel_indexes

            blk.segments += [seg]

            seg.analogsignals = self.read_analogsignal(lazy=lazy,
                                                       cascade=cascade)
            try:
                seg.irregularlysampledsignals = self.read_tracking()
            except Exception as e:
                print('Warning: unable to read tracking')
                print(e)
            seg.spiketrains = self.read_spiketrain()

            # TODO Call all other read functions

            seg.duration = self._duration

            # TODO May need to "populate_RecordingChannel"

            # spiketrain = self.read_spiketrain()

            # seg.spiketrains.append()

        blk.create_many_to_one_relationship()
        return blk
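A minimal usage sketch for the read_block method above. The IO class name (ExampleIO), its module and the file name are purely illustrative assumptions; only the read_block call and the Block/Segment traversal come from the code above.

io = ExampleIO(filename="recording.dat")        # hypothetical IO class and file
block = io.read_block(lazy=False, cascade=True)

# with cascade=True the Block carries one Segment populated with analog
# signals, tracking signals (if readable) and spike trains
for seg in block.segments:
    print(len(seg.analogsignals), len(seg.irregularlysampledsignals), len(seg.spiketrains))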
Example #2
def prune_segment(segment: Segment) -> None:
    segment.analogsignals = [a for a in segment.analogsignals if "type_id" in a.annotations]
    segment.epochs = [ep for ep in segment.epochs if "type_id" in ep.annotations]
    segment.events = [ev for ev in segment.events if "type_id" in ev.annotations]
    segment.irregularlysampledsignals = [i for i in segment.irregularlysampledsignals if "type_id" in i.annotations]
    segment.spiketrains = [st for st in segment.spiketrains if "type_id" in st.annotations]
    segment.imagesequences = [i for i in segment.imagesequences if "type_id" in i.annotations]
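A short self-contained sketch of what prune_segment keeps and drops, assuming only the standard neo and quantities APIs (and a neo version recent enough to have Segment.imagesequences): children whose annotations contain a "type_id" key survive, all others are removed.

import quantities as pq
from neo.core import Segment, SpikeTrain

seg = Segment()
kept = SpikeTrain([1.0, 2.0] * pq.s, t_stop=10.0 * pq.s)
kept.annotate(type_id=7)
dropped = SpikeTrain([3.0] * pq.s, t_stop=10.0 * pq.s)  # no "type_id" annotation
seg.spiketrains = [kept, dropped]

prune_segment(seg)
print(len(seg.spiketrains))         # 1 -- only the annotated train remains
print(seg.spiketrains[0] is kept)   # True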
Example #3
File: gdfio.py    Project: rgutzen/UP-Tasks
    def read_segment(self,
                     lazy=False,
                     cascade=True,
                     gdf_id_list=None,
                     time_unit=pq.ms,
                     t_start=None,
                     t_stop=None,
                     id_column=0,
                     time_column=1,
                     **args):
        """
        Read a Segment which contains SpikeTrain(s) with specified neuron IDs
        from the GDF data.

        Parameters
        ----------
        lazy : bool, optional, default: False
        cascade : bool, optional, default: True
        gdf_id_list : list or tuple, default: None
            Either a list of GDF IDs for which to return SpikeTrain(s), or a
            tuple specifying the range (boundaries [start, stop] included) of
            GDF IDs. Must be specified if the GDF file contains neuron IDs;
            the default None raises an error. Pass an empty list [] to
            retrieve the spike trains of all neurons with at least one spike.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps.
        t_start : Quantity (time), default: None
            Start time of the SpikeTrain. t_start must be specified; the
            default None raises an error.
        t_stop : Quantity (time), default: None
            Stop time of the SpikeTrain. t_stop must be specified; the
            default None raises an error.
        id_column : int, optional, default: 0
            Column index of neuron IDs.
        time_column : int, optional, default: 1
            Column index of time stamps.

        Returns
        -------
        seg : Segment
            The Segment contains one SpikeTrain for each ID in gdf_id_list.
        """

        if isinstance(gdf_id_list, tuple):
            gdf_id_list = range(gdf_id_list[0], gdf_id_list[1] + 1)

        # __read_spiketrains() needs a list of IDs
        if gdf_id_list is None:
            gdf_id_list = [None]

        # create an empty Segment and fill in the spike trains
        seg = Segment()
        seg.spiketrains = self.__read_spiketrains(gdf_id_list, time_unit,
                                                  t_start, t_stop, id_column,
                                                  time_column, **args)

        return seg
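A usage sketch for the read_segment above. The enclosing IO class and its constructor are assumptions (called GdfIO here only because the file is gdfio.py); the keyword arguments follow the docstring.

import quantities as pq

io = GdfIO(filename="spikes-1234-0.gdf")             # hypothetical class and file
seg = io.read_segment(gdf_id_list=(5, 10),            # tuple -> IDs 5..10 inclusive
                      time_unit=pq.ms,
                      t_start=0.0 * pq.ms,
                      t_stop=1000.0 * pq.ms,
                      id_column=0,
                      time_column=1)
print(len(seg.spiketrains))   # one SpikeTrain per ID in the requested range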
Example #4
    def _read_segment(self, node, parent):
        attributes = self._get_standard_attributes(node)
        segment = Segment(**attributes)

        signals = []
        for name, child_node in node['analogsignals'].items():
            if "AnalogSignal" in name:
                signals.append(self._read_analogsignal(child_node, parent=segment))
        if signals and self.merge_singles:
            segment.unmerged_analogsignals = signals  # signals will be merged later
            signals = []
        for name, child_node in node['analogsignalarrays'].items():
            if "AnalogSignalArray" in name:
                signals.append(self._read_analogsignalarray(child_node, parent=segment))
        segment.analogsignals = signals

        irr_signals = []
        for name, child_node in node['irregularlysampledsignals'].items():
            if "IrregularlySampledSignal" in name:
                irr_signals.append(self._read_irregularlysampledsignal(child_node, parent=segment))
        if irr_signals and self.merge_singles:
            segment.unmerged_irregularlysampledsignals = irr_signals
            irr_signals = []
        segment.irregularlysampledsignals = irr_signals

        epochs = []
        for name, child_node in node['epochs'].items():
            if "Epoch" in name:
                epochs.append(self._read_epoch(child_node, parent=segment))
        if self.merge_singles:
            epochs = self._merge_data_objects(epochs)
        for name, child_node in node['epocharrays'].items():
            if "EpochArray" in name:
                epochs.append(self._read_epocharray(child_node, parent=segment))
        segment.epochs = epochs

        events = []
        for name, child_node in node['events'].items():
            if "Event" in name:
                events.append(self._read_event(child_node, parent=segment))
        if self.merge_singles:
            events = self._merge_data_objects(events)
        for name, child_node in node['eventarrays'].items():
            if "EventArray" in name:
                events.append(self._read_eventarray(child_node, parent=segment))
        segment.events = events

        spiketrains = []
        for name, child_node in node['spikes'].items():
            raise NotImplementedError('Spike objects not yet handled.')
        for name, child_node in node['spiketrains'].items():
            if "SpikeTrain" in name:
                spiketrains.append(self._read_spiketrain(child_node, parent=segment))
        segment.spiketrains = spiketrains

        segment.block = parent
        return segment
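The reader above walks a mapping-like node (for example an HDF5 group): each container key maps child names to child nodes, and the substring checks ("AnalogSignal" in name, ...) select the matching reader. A rough sketch of the layout it expects, with illustrative names only:

node = {
    'analogsignals':             {'AnalogSignal_0': ..., 'AnalogSignal_1': ...},
    'analogsignalarrays':        {'AnalogSignalArray_0': ...},
    'irregularlysampledsignals': {},
    'epochs':                    {'Epoch_0': ...},
    'epocharrays':               {},
    'events':                    {'Event_0': ...},
    'eventarrays':               {},
    'spikes':                    {},          # any entry here raises NotImplementedError
    'spiketrains':               {'SpikeTrain_0': ...},
}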
Example #6
    def read_segment(self, lazy=False, cascade=True,
                     gdf_id_list=None, time_unit=pq.ms, t_start=None,
                     t_stop=None, id_column=0, time_column=1, **args):
        """
        Read a Segment which contains SpikeTrain(s) with specified neuron IDs
        from the GDF data.

        Parameters
        ----------
        lazy : bool, optional, default: False
        cascade : bool, optional, default: True
        gdf_id_list : list or tuple, default: None
            Either a list of GDF IDs for which to return SpikeTrain(s), or a
            tuple specifying the range (boundaries [start, stop] included) of
            GDF IDs. Must be specified if the GDF file contains neuron IDs;
            the default None raises an error. Pass an empty list [] to
            retrieve the spike trains of all neurons with at least one spike.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps.
        t_start : Quantity (time), default: None
            Start time of the SpikeTrain. t_start must be specified; the
            default None raises an error.
        t_stop : Quantity (time), default: None
            Stop time of the SpikeTrain. t_stop must be specified; the
            default None raises an error.
        id_column : int, optional, default: 0
            Column index of neuron IDs.
        time_column : int, optional, default: 1
            Column index of time stamps.

        Returns
        -------
        seg : Segment
            The Segment contains one SpikeTrain for each ID in gdf_id_list.
        """

        if isinstance(gdf_id_list, tuple):
            gdf_id_list = range(gdf_id_list[0], gdf_id_list[1] + 1)

        # __read_spiketrains() needs a list of IDs
        if gdf_id_list is None:
            gdf_id_list = [None]

        # create an empty Segment and fill in the spike trains
        seg = Segment()
        seg.spiketrains = self.__read_spiketrains(gdf_id_list,
                                                  time_unit, t_start,
                                                  t_stop,
                                                  id_column, time_column,
                                                  **args)

        return seg
Example #7
def proc_src_condition(rep, filename, ADperiod, side, block):
    '''Get the condition in a src file that has been processed by the official
    matlab function.  See proc_src for details'''

    chx = block.channel_indexes[0]

    stim = rep['stim'].flatten()
    params = [str(res[0]) for res in stim['paramName'][0].flatten()]
    values = [res for res in stim['paramVal'][0].flatten()]
    stim = dict(zip(params, values))
    sweepLen = rep['sweepLen'][0, 0]

    if not len(rep):
        return

    unassignedSpikes = rep['unassignedSpikes'].flatten()
    if len(unassignedSpikes):
        damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
        timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
        spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
        respWin = np.array([], dtype=np.int32)
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        chx.units[0].spiketrains.extend(trains)
        atrains = [trains]
    else:
        damaIndexes = []
        timeStamps = []
        atrains = []

    clusters = rep['clusters'].flatten()
    if len(clusters):
        IdStrings = [res[0] for res in clusters['IdString']]
        sweepLens = [res[0, 0] for res in clusters['sweepLen']]
        respWins = [res.flatten() for res in clusters['respWin']]
        spikeunits = []
        for cluster in clusters['sweeps']:
            if len(cluster):
                spikes = [res.flatten() for res in
                          cluster['spikes'].flatten()]
            else:
                spikes = []
            spikeunits.append(spikes)
    else:
        IdStrings = []
        sweepLens = []
        respWins = []
        spikeunits = []

    for unit, IdString in zip(chx.units[1:], IdStrings):
        unit.name = str(IdString)

    fullunit = zip(spikeunits, chx.units[1:], sweepLens, respWins)
    for spikeunit, unit, sweepLen, respWin in fullunit:
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        atrains.append(trains)
        unit.spiketrains.extend(trains)

    atrains = zip(*atrains)
    for trains in atrains:
        segment = Segment(file_origin=filename, feature_type=-1,
                          go_by_closest_unit_center=False,
                          include_unit_bounds=False, **stim)
        block.segments.append(segment)
        segment.spiketrains = trains
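The zip(*atrains) step regroups the per-unit lists (one list of trains per unit, one train per sweep) into per-sweep tuples, so each new Segment collects the trains of every unit for a single sweep. A toy illustration of that transpose:

# each inner list: one entry per sweep, for a single unit
atrains = [['u0_s0', 'u0_s1', 'u0_s2'],    # unit 0
           ['u1_s0', 'u1_s1', 'u1_s2']]    # unit 1
per_sweep = list(zip(*atrains))
print(per_sweep[0])   # ('u0_s0', 'u1_s0') -> becomes the first Segment's spiketrains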
Example #8
def proc_src_condition(rep, filename, ADperiod, side, block):
    '''Get the condition in a src file that has been processed by the official
    matlab function.  See proc_src for details'''

    rcg = block.recordingchannelgroups[0]

    stim = rep['stim'].flatten()
    params = [str(res[0]) for res in stim['paramName'][0].flatten()]
    values = [res for res in stim['paramVal'][0].flatten()]
    stim = dict(zip(params, values))
    sweepLen = rep['sweepLen'][0, 0]

    if not len(rep):
        return

    unassignedSpikes = rep['unassignedSpikes'].flatten()
    if len(unassignedSpikes):
        damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
        timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
        spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
        respWin = np.array([], dtype=np.int32)
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        rcg.units[0].spiketrains.extend(trains)
        atrains = [trains]
    else:
        damaIndexes = []
        timeStamps = []
        atrains = []

    clusters = rep['clusters'].flatten()
    if len(clusters):
        IdStrings = [res[0] for res in clusters['IdString']]
        sweepLens = [res[0, 0] for res in clusters['sweepLen']]
        respWins = [res.flatten() for res in clusters['respWin']]
        spikeunits = []
        for cluster in clusters['sweeps']:
            if len(cluster):
                spikes = [res.flatten() for res in cluster['spikes'].flatten()]
            else:
                spikes = []
            spikeunits.append(spikes)
    else:
        IdStrings = []
        sweepLens = []
        respWins = []
        spikeunits = []

    for unit, IdString in zip(rcg.units[1:], IdStrings):
        unit.name = str(IdString)

    fullunit = zip(spikeunits, rcg.units[1:], sweepLens, respWins)
    for spikeunit, unit, sweepLen, respWin in fullunit:
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        atrains.append(trains)
        unit.spiketrains.extend(trains)

    atrains = zip(*atrains)
    for trains in atrains:
        segment = Segment(file_origin=filename,
                          feature_type=-1,
                          go_by_closest_unit_center=False,
                          include_unit_bounds=False,
                          **stim)
        block.segments.append(segment)
        segment.spiketrains = trains
Example #9
def proc_f32(filename):
    """Load an f32 file that has already been processed by the official matlab
    file converter.  That matlab data is saved to an m-file, which is then
    converted to a numpy '.npz' file.  This numpy file is the file actually
    loaded.  This function converts it to a neo block and returns the block.
    This block can be compared to the block produced by BrainwareF32IO to
    make sure BrainwareF32IO is working properly

    block = proc_f32(filename)

    filename: The file name of the numpy file to load.  It should end with
    '*_f32_py?.npz'. This will be converted to a neo 'file_origin' property
    with the value '*.f32', so the filename to compare should fit that pattern.
    'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
    for the python 3 version of the numpy file.

    example: filename = 'file1_f32_py2.npz'
             f32 file name = 'file1.f32'
    """

    filenameorig = os.path.basename(filename[:-12] + ".f32")

    # create the objects to store other objects
    block = Block(file_origin=filenameorig)
    rcg = RecordingChannelGroup(file_origin=filenameorig)
    rcg.channel_indexes = np.array([], dtype=int)
    rcg.channel_names = np.array([], dtype="S")
    unit = Unit(file_origin=filenameorig)

    # load objects into their containers
    block.recordingchannelgroups.append(rcg)
    rcg.units.append(unit)

    try:
        with np.load(filename) as f32obj:
            f32file = list(f32obj.items())[0][1].flatten()
    except IOError as exc:
        if "as a pickle" in str(exc):
            block.create_many_to_one_relationship()
            return block
        else:
            raise

    sweeplengths = [res[0, 0].tolist() for res in f32file["sweeplength"]]
    stims = [res.flatten().tolist() for res in f32file["stim"]]

    sweeps = [res["spikes"].flatten() for res in f32file["sweep"] if res.size]

    fullf32 = zip(sweeplengths, stims, sweeps)
    for sweeplength, stim, sweep in fullf32:
        for trainpts in sweep:
            if trainpts.size:
                trainpts = trainpts.flatten().astype("float32")
            else:
                trainpts = []

            paramnames = ["Param%s" % i for i in range(len(stim))]
            params = dict(zip(paramnames, stim))
            train = SpikeTrain(trainpts, units=pq.ms, t_start=0, t_stop=sweeplength, file_origin=filenameorig)

            segment = Segment(file_origin=filenameorig, **params)
            segment.spiketrains = [train]
            unit.spiketrains.append(train)
            block.segments.append(segment)

    block.create_many_to_one_relationship()

    return block
Example #10
def proc_f32(filename):
    '''Load an f32 file that has already been processed by the official matlab
    file converter.  That matlab data is saved to an m-file, which is then
    converted to a numpy '.npz' file.  This numpy file is the file actually
    loaded.  This function converts it to a neo block and returns the block.
    This block can be compared to the block produced by BrainwareF32IO to
    make sure BrainwareF32IO is working properly

    block = proc_f32(filename)

    filename: The file name of the numpy file to load.  It should end with
    '*_f32_py?.npz'. This will be converted to a neo 'file_origin' property
    with the value '*.f32', so the filename to compare should fit that pattern.
    'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
    for the python 3 version of the numpy file.

    example: filename = 'file1_f32_py2.npz'
             f32 file name = 'file1.f32'
    '''

    filenameorig = os.path.basename(filename[:-12] + '.f32')

    # create the objects to store other objects
    block = Block(file_origin=filenameorig)
    chx = ChannelIndex(file_origin=filenameorig,
                       index=np.array([], dtype=int),
                       channel_names=np.array([], dtype='S'))
    unit = Unit(file_origin=filenameorig)

    # load objects into their containers
    block.channel_indexes.append(chx)
    chx.units.append(unit)

    try:
        with np.load(filename) as f32obj:
            f32file = list(f32obj.items())[0][1].flatten()
    except IOError as exc:
        if 'as a pickle' in str(exc):
            block.create_many_to_one_relationship()
            return block
        else:
            raise

    sweeplengths = [res[0, 0].tolist() for res in f32file['sweeplength']]
    stims = [res.flatten().tolist() for res in f32file['stim']]

    sweeps = [res['spikes'].flatten() for res in f32file['sweep'] if res.size]

    fullf32 = zip(sweeplengths, stims, sweeps)
    for sweeplength, stim, sweep in fullf32:
        for trainpts in sweep:
            if trainpts.size:
                trainpts = trainpts.flatten().astype('float32')
            else:
                trainpts = []

            paramnames = ['Param%s' % i for i in range(len(stim))]
            params = dict(zip(paramnames, stim))
            train = SpikeTrain(trainpts,
                               units=pq.ms,
                               t_start=0,
                               t_stop=sweeplength,
                               file_origin=filenameorig)

            segment = Segment(file_origin=filenameorig, **params)
            segment.spiketrains = [train]
            unit.spiketrains.append(train)
            block.segments.append(segment)

    block.create_many_to_one_relationship()

    return block
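Both proc_f32 variants are test helpers: the Block they build from the pre-converted .npz file is meant to be compared against the Block that BrainwareF32IO reads from the original .f32 file. A rough sketch of that comparison; the file names are hypothetical and the exact comparison helper used in neo's test suite may differ.

from neo.io import BrainwareF32IO

block_from_npz = proc_f32('file1_f32_py3.npz')                 # file_origin -> 'file1.f32'
block_from_io = BrainwareF32IO(filename='file1.f32').read_block()
print(len(block_from_npz.segments), len(block_from_io.segments))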
Example #11
    def read_segment(self, gid_list=None, time_unit=pq.ms, t_start=None,
                     t_stop=None, sampling_period=None, id_column_dat=0,
                     time_column_dat=1, value_columns_dat=2,
                     id_column_gdf=0, time_column_gdf=1, value_types=None,
                     value_units=None, lazy=False):
        """
        Reads a Segment which contains SpikeTrain(s) with specified neuron IDs
        from the GDF data.

        Parameters
        ----------
        gid_list : list, default: None
            A list of GDF IDs for which to return SpikeTrain(s). gid_list must
            be specified if the GDF file contains neuron IDs; the default None
            raises an error. Pass an empty list [] to retrieve the spike
            trains of all neurons.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps in DAT as well as GDF files.
        t_start : Quantity (time), optional, default: 0 * pq.ms
            Start time of SpikeTrain.
        t_stop : Quantity (time), default: None
            Stop time of the SpikeTrain. t_stop must be specified; the default
            None raises an error.
        sampling_period : Quantity (time), optional, default: None
            Sampling period of the recorded data.
        id_column_dat : int, optional, default: 0
            Column index of neuron IDs in the DAT file.
        time_column_dat : int, optional, default: 1
            Column index of time stamps in the DAT file.
        value_columns_dat : int, optional, default: 2
            Column index of the analog values recorded in the DAT file.
        id_column_gdf : int, optional, default: 0
            Column index of neuron IDs in the GDF file.
        time_column_gdf : int, optional, default: 1
            Column index of time stamps in the GDF file.
        value_types : str, optional, default: None
            NEST data type of the recorded analog values, e.g. 'V_m', 'I', 'g_e'.
        value_units : Quantity (amplitude), default: None
            The physical unit of the recorded signal values.
        lazy : bool, optional, default: False

        Returns
        -------
        seg : Segment
            The Segment contains one SpikeTrain and one AnalogSignal for
            each ID in gid_list.
        """
        assert not lazy, 'Do not support lazy'

        if isinstance(gid_list, tuple):
            if gid_list[0] > gid_list[1]:
                raise ValueError('The second entry in gid_list must be '
                                 'greater or equal to the first entry.')
            gid_list = range(gid_list[0], gid_list[1] + 1)

        # __read_xxx() needs a list of IDs
        if gid_list is None:
            gid_list = [None]

        # create an empty Segment
        seg = Segment(file_origin=",".join(self.filenames))
        seg.file_datetime = datetime.fromtimestamp(os.stat(self.filenames[0]).st_mtime)
        # todo: rather than take the first file for the timestamp, we should take the oldest
        #       in practice, there won't be much difference

        # Load analogsignals and attach to Segment
        if 'dat' in self.avail_formats:
            seg.analogsignals = self.__read_analogsignals(
                gid_list,
                time_unit,
                t_start,
                t_stop,
                sampling_period=sampling_period,
                id_column=id_column_dat,
                time_column=time_column_dat,
                value_columns=value_columns_dat,
                value_types=value_types,
                value_units=value_units)
        if 'gdf' in self.avail_formats:
            seg.spiketrains = self.__read_spiketrains(
                gid_list,
                time_unit,
                t_start,
                t_stop,
                id_column=id_column_gdf,
                time_column=time_column_gdf)

        return seg
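A usage sketch for this read_segment. The enclosing IO class (called NestIO here) and its filenames argument are assumptions inferred from self.filenames and self.avail_formats; the keyword arguments follow the docstring above, and the file names are made up.

import quantities as pq

io = NestIO(filenames=['V_m-1234-0.dat', 'spikes-1234-0.gdf'])   # hypothetical files
seg = io.read_segment(gid_list=[],                 # empty list -> all neurons
                      t_start=0 * pq.ms,
                      t_stop=1000 * pq.ms,
                      sampling_period=0.1 * pq.ms,
                      value_columns_dat=2,
                      value_types='V_m',
                      value_units=pq.mV)
print(len(seg.analogsignals), len(seg.spiketrains))   # one of each per neuron ID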
Example #12
    def read_segment(self,
                     gid_list=None,
                     time_unit=pq.ms,
                     t_start=None,
                     t_stop=None,
                     sampling_period=None,
                     id_column_dat=0,
                     time_column_dat=1,
                     value_columns_dat=2,
                     id_column_gdf=0,
                     time_column_gdf=1,
                     value_types=None,
                     value_units=None,
                     lazy=False,
                     cascade=True):
        """
        Reads a Segment which contains SpikeTrain(s) with specified neuron IDs
        from the GDF data.

        Parameters
        ----------
        gid_list : list, default: None
            A list of GDF IDs for which to return SpikeTrain(s). gid_list must
            be specified if the GDF file contains neuron IDs; the default None
            raises an error. Pass an empty list [] to retrieve the spike
            trains of all neurons.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps in DAT as well as GDF files.
        t_start : Quantity (time), optional, default: 0 * pq.ms
            Start time of SpikeTrain.
        t_stop : Quantity (time), default: None
            Stop time of the SpikeTrain. t_stop must be specified; the default
            None raises an error.
        sampling_period : Quantity (time), optional, default: None
            Sampling period of the recorded data.
        id_column_dat : int, optional, default: 0
            Column index of neuron IDs in the DAT file.
        time_column_dat : int, optional, default: 1
            Column index of time stamps in the DAT file.
        value_columns_dat : int, optional, default: 2
            Column index of the analog values recorded in the DAT file.
        id_column_gdf : int, optional, default: 0
            Column index of neuron IDs in the GDF file.
        time_column_gdf : int, optional, default: 1
            Column index of time stamps in the GDF file.
        value_types : str, optional, default: None
            NEST data type of the recorded analog values, e.g. 'V_m', 'I', 'g_e'.
        value_units : Quantity (amplitude), default: None
            The physical unit of the recorded signal values.
        lazy : bool, optional, default: False
        cascade : bool, optional, default: True

        Returns
        -------
        seg : Segment
            The Segment contains one SpikeTrain and one AnalogSignal for
            each ID in gid_list.
        """
        if isinstance(gid_list, tuple):
            if gid_list[0] > gid_list[1]:
                raise ValueError('The second entry in gid_list must be '
                                 'greater or equal to the first entry.')
            gid_list = range(gid_list[0], gid_list[1] + 1)

        # __read_xxx() needs a list of IDs
        if gid_list is None:
            gid_list = [None]

        # create an empty Segment
        seg = Segment(file_origin=",".join(self.filenames))
        seg.file_datetime = datetime.fromtimestamp(
            os.stat(self.filenames[0]).st_mtime)
        # todo: rather than take the first file for the timestamp, we should take the oldest
        #       in practice, there won't be much difference

        if cascade:
            # Load analogsignals and attach to Segment
            if 'dat' in self.avail_formats:
                seg.analogsignals = self.__read_analogsignals(
                    gid_list,
                    time_unit,
                    t_start,
                    t_stop,
                    sampling_period=sampling_period,
                    id_column=id_column_dat,
                    time_column=time_column_dat,
                    value_columns=value_columns_dat,
                    value_types=value_types,
                    value_units=value_units,
                    lazy=lazy)
            if 'gdf' in self.avail_formats:
                seg.spiketrains = self.__read_spiketrains(
                    gid_list,
                    time_unit,
                    t_start,
                    t_stop,
                    id_column=id_column_gdf,
                    time_column=time_column_gdf)

        return seg
Example #13
    def test__cut_block_by_epochs(self):
        epoch = Epoch([0.5, 10.0, 25.2] * pq.s,
                      durations=[5.1, 4.8, 5.0] * pq.s,
                      t_start=.1 * pq.s)
        epoch.annotate(epoch_type='a', pick='me')
        epoch.array_annotate(trial_id=[1, 2, 3])

        epoch2 = Epoch([0.6, 9.5, 16.8, 34.1] * pq.s,
                       durations=[4.5, 4.8, 5.0, 5.0] * pq.s,
                       t_start=.1 * pq.s)
        epoch2.annotate(epoch_type='b')
        epoch2.array_annotate(trial_id=[1, 2, 3, 4])

        event = Event(times=[0.5, 10.0, 25.2] * pq.s, t_start=.1 * pq.s)
        event.annotate(event_type='trial start')
        event.array_annotate(trial_id=[1, 2, 3])

        anasig = AnalogSignal(np.arange(50.0) * pq.mV,
                              t_start=.1 * pq.s,
                              sampling_rate=1.0 * pq.Hz)
        irrsig = IrregularlySampledSignal(signal=np.arange(50.0) * pq.mV,
                                          times=anasig.times,
                                          t_start=.1 * pq.s)
        st = SpikeTrain(
            np.arange(0.5, 50, 7) * pq.s,
            t_start=.1 * pq.s,
            t_stop=50.0 * pq.s,
            waveforms=np.array(
                [[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]],
                 [[4., 5.], [4.1, 5.1]], [[6., 7.], [6.1, 7.1]],
                 [[8., 9.], [8.1, 9.1]], [[12., 13.], [12.1, 13.1]],
                 [[14., 15.], [14.1, 15.1]], [[16., 17.], [16.1, 17.1]]]) *
            pq.mV,
            array_annotations={'spikenum': np.arange(1, 9)})

        seg = Segment()
        seg2 = Segment(name='NoCut')
        seg.epochs = [epoch, epoch2]
        seg.events = [event]
        seg.analogsignals = [anasig]
        seg.irregularlysampledsignals = [irrsig]
        seg.spiketrains = [st]

        block = Block()
        block.segments = [seg, seg2]
        block.create_many_to_one_relationship()

        # test without resetting the time
        cut_block_by_epochs(block, properties={'pick': 'me'})

        assert_neo_object_is_compliant(block)
        self.assertEqual(len(block.segments), 3)

        for epoch_idx in range(len(epoch)):
            self.assertEqual(len(block.segments[epoch_idx].events), 1)
            self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
            self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
            self.assertEqual(
                len(block.segments[epoch_idx].irregularlysampledsignals), 1)

            if epoch_idx != 0:
                self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
            else:
                self.assertEqual(len(block.segments[epoch_idx].epochs), 2)

            assert_same_attributes(
                block.segments[epoch_idx].spiketrains[0],
                st.time_slice(t_start=epoch.times[epoch_idx],
                              t_stop=epoch.times[epoch_idx] +
                              epoch.durations[epoch_idx]))
            assert_same_attributes(
                block.segments[epoch_idx].analogsignals[0],
                anasig.time_slice(t_start=epoch.times[epoch_idx],
                                  t_stop=epoch.times[epoch_idx] +
                                  epoch.durations[epoch_idx]))
            assert_same_attributes(
                block.segments[epoch_idx].irregularlysampledsignals[0],
                irrsig.time_slice(t_start=epoch.times[epoch_idx],
                                  t_stop=epoch.times[epoch_idx] +
                                  epoch.durations[epoch_idx]))
            assert_same_attributes(
                block.segments[epoch_idx].events[0],
                event.time_slice(t_start=epoch.times[epoch_idx],
                                 t_stop=epoch.times[epoch_idx] +
                                 epoch.durations[epoch_idx]))
        assert_same_attributes(
            block.segments[0].epochs[0],
            epoch.time_slice(t_start=epoch.times[0],
                             t_stop=epoch.times[0] + epoch.durations[0]))
        assert_same_attributes(
            block.segments[0].epochs[1],
            epoch2.time_slice(t_start=epoch.times[0],
                              t_stop=epoch.times[0] + epoch.durations[0]))

        seg = Segment()
        seg2 = Segment(name='NoCut')
        seg.epochs = [epoch, epoch2]
        seg.events = [event]
        seg.analogsignals = [anasig]
        seg.irregularlysampledsignals = [irrsig]
        seg.spiketrains = [st]

        block = Block()
        block.segments = [seg, seg2]
        block.create_many_to_one_relationship()

        # test with resetting the time
        cut_block_by_epochs(block, properties={'pick': 'me'}, reset_time=True)

        assert_neo_object_is_compliant(block)
        self.assertEqual(len(block.segments), 3)

        for epoch_idx in range(len(epoch)):
            self.assertEqual(len(block.segments[epoch_idx].events), 1)
            self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
            self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
            self.assertEqual(
                len(block.segments[epoch_idx].irregularlysampledsignals), 1)
            if epoch_idx != 0:
                self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
            else:
                self.assertEqual(len(block.segments[epoch_idx].epochs), 2)

            assert_same_attributes(
                block.segments[epoch_idx].spiketrains[0],
                st.time_shift(-epoch.times[epoch_idx]).time_slice(
                    t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx]))

            anasig_target = anasig.time_shift(-epoch.times[epoch_idx])
            anasig_target = anasig_target.time_slice(
                t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx])
            assert_same_attributes(block.segments[epoch_idx].analogsignals[0],
                                   anasig_target)
            irrsig_target = irrsig.time_shift(-epoch.times[epoch_idx])
            irrsig_target = irrsig_target.time_slice(
                t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx])
            assert_same_attributes(
                block.segments[epoch_idx].irregularlysampledsignals[0],
                irrsig_target)
            assert_same_attributes(
                block.segments[epoch_idx].events[0],
                event.time_shift(-epoch.times[epoch_idx]).time_slice(
                    t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx]))

        assert_same_attributes(
            block.segments[0].epochs[0],
            epoch.time_shift(-epoch.times[0]).time_slice(
                t_start=0 * pq.s, t_stop=epoch.durations[0]))
        assert_same_attributes(
            block.segments[0].epochs[1],
            epoch2.time_shift(-epoch.times[0]).time_slice(
                t_start=0 * pq.s, t_stop=epoch.durations[0]))
Example #14
    def read_block(self,
                   lazy=False,
                   cascade=True,
                   signal_names=None,
                   signal_units=None):
        block = Block(file_origin=self.filename)
        segment = Segment(name="default")
        block.segments.append(segment)
        segment.block = block

        spike_times = defaultdict(list)
        spike_file = self.filename + ".dat"
        print("SPIKEFILE: {}".format(spike_file))
        if os.path.exists(spike_file):
            print("Loading data from {}".format(spike_file))
            with open(spike_file, 'r') as fp:
                for line in fp:
                    if line[0] != '#':
                        entries = line.strip().split()
                        if len(entries) > 1:
                            time = float(entries[0])
                            for id in entries[1:]:
                                spike_times[id].append(time)
                t_stop = float(entries[0])
            min_id = min(map(int, spike_times))
            segment.spiketrains = [
                SpikeTrain(times,
                           t_stop=t_stop,
                           units="ms",
                           id=int(id),
                           source_index=int(id) - min_id)
                for id, times in spike_times.items()
            ]
        signal_files = glob("{}_state.*.dat".format(self.filename))
        print(signal_files)
        for signal_file in signal_files:
            print("Loading data from {}".format(signal_file))
            population = os.path.basename(signal_file).split(".")[1]
            try:
                data = np.loadtxt(signal_file, delimiter=", ")
            except ValueError:
                print("Couldn't load data from file {}".format(signal_file))
                continue
            t_start = data[0, 1]
            ids = data[:, 0]
            unique_ids = np.unique(ids)
            for column in range(2, data.shape[1]):
                if signal_names is None:
                    signal_name = "signal{}".format(column - 2)
                else:
                    signal_name = signal_names[column - 2]
                if signal_units is None:
                    units = "mV"  # seems like a reasonable default
                else:
                    units = signal_units[column - 2]
                signals_by_id = {}
                for id in unique_ids:
                    times = data[ids == id, 1]
                    unique_times, idx = np.unique(
                        times, return_index=True
                    )  # some time points are represented twice
                    signals_by_id[id] = data[ids == id, column][idx]
                    assert signals_by_id[id].shape == signals_by_id[
                        unique_ids[0]].shape
                channel_ids = np.array(list(signals_by_id.keys()))
                sampling_period = unique_times[1] - unique_times[0]
                assert sampling_period != 0.0, sampling_period
                signal = AnalogSignal(np.vstack(signals_by_id.values()).T,
                                      units=units,
                                      t_start=t_start * pq.ms,
                                      sampling_period=sampling_period * pq.ms,
                                      name=signal_name,
                                      population=population)
                #signal.channel_index = ChannelIndex(np.arange(signal.shape[1], int),
                #                                    channel_ids=channel_ids)
                signal.channel_index = channel_ids
                segment.analogsignals.append(signal)

        return block
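For reference, a hypothetical miniature of the plain-text files this reader (and the variant below) parses, inferred from the parsing code: the spike file holds a time followed by the ids that spiked at that time, and each *_state.<population>.dat file holds comma-separated rows of id, time and one or more signal values.

# hypothetical file contents, matching what the parser above expects
with open("sim_output.dat", "w") as f:            # spikes: "<time> <id> <id> ..."
    f.write("# comment lines start with '#'\n")
    f.write("0.1 3 7\n")
    f.write("0.2 3 7\n")
with open("sim_output_state.exc.dat", "w") as f:  # state: "<id>, <time>, <value>, ..."
    f.write("3, 0.1, -65.0\n")
    f.write("7, 0.1, -64.5\n")
    f.write("3, 0.2, -64.8\n")
    f.write("7, 0.2, -64.2\n")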
Example #15
    def read_block(self, lazy=False, cascade=True, signal_names=None, signal_units=None):
        block = Block(file_origin=self.filename)
        segment = Segment(name="default")
        block.segments.append(segment)
        segment.block = block

        spike_times = defaultdict(list)
        spike_file = self.filename + ".dat"
        print("SPIKEFILE: {}".format(spike_file))
        if os.path.exists(spike_file):
            print("Loading data from {}".format(spike_file))
            with open(spike_file, 'r') as fp:
                for line in fp:
                    if line[0] != '#':
                        entries = line.strip().split()
                        if len(entries) > 1:
                            time = float(entries[0])
                            for id in entries[1:]:
                                spike_times[id].append(time)
                t_stop = float(entries[0])
            if spike_times:
                min_id = min(map(int, spike_times))
            segment.spiketrains = [SpikeTrain(times, t_stop=t_stop, units="ms",
                                              id=int(id), source_index=int(id) - min_id)
                                   for id, times in spike_times.items()]
        signal_files = glob("{}_state.*.dat".format(self.filename))
        print(signal_files)
        for signal_file in signal_files:
            print("Loading data from {}".format(signal_file))
            population = os.path.basename(signal_file).split(".")[1]
            try:
                data = np.loadtxt(signal_file, delimiter=", ")
            except ValueError:
                print("Couldn't load data from file {}".format(signal_file))
                continue
            t_start = data[0, 1]
            ids = data[:, 0]
            unique_ids = np.unique(ids)
            for column in range(2, data.shape[1]):
                if signal_names is None:
                    signal_name = "signal{}".format(column - 2)
                else:
                    signal_name = signal_names[column - 2]
                if signal_units is None:
                    units = "mV"  # seems like a reasonable default
                else:
                    units = signal_units[column - 2]
                signals_by_id = {}
                for id in unique_ids:
                    times = data[ids==id, 1]
                    unique_times, idx = np.unique(times, return_index=True)  # some time points are represented twice
                    signals_by_id[id] = data[ids==id, column][idx]
                channel_ids = np.array(list(signals_by_id.keys()))
                if len(unique_times) > 1:
                    sampling_period = unique_times[1] - unique_times[0]
                    assert sampling_period != 0.0, sampling_period
                    signal_lengths = np.array([s.size for s in signals_by_id.values()])
                    min_length = signal_lengths.min()
                    if not (signal_lengths == signal_lengths[0]).all():
                        print("Warning: signals have different sizes: min={}, max={}".format(min_length,
                                                                                             signal_lengths.max()))
                        print("Truncating to length {}".format(min_length))
                    signal = AnalogSignal(np.vstack([s[:min_length] for s in signals_by_id.values()]).T,
                                          units=units,
                                          t_start=t_start * pq.ms,
                                          sampling_period=sampling_period*pq.ms,
                                          name=signal_name,
                                          population=population)
                    #signal.channel_index = ChannelIndex(np.arange(signal.shape[1], int),
                    #                                    channel_ids=channel_ids)
                    signal.channel_index = channel_ids
                    segment.analogsignals.append(signal)

        return block