Example #1
def test_spike_argmax_indexed():
    train1 = SpikeTrain([3, 4] * s, t_stop=10)
    train2 = SpikeTrain([5, 6] * s, t_stop=10)
    trains = np.array([train1, train2])
    out = v.spike_argmax(trains)
    expected = np.array([1, 0])
    assert np.array_equal(out, expected)
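Note: `v.spike_argmax` itself is never shown in these examples. A minimal sketch consistent with the tests above and below (a hypothetical reimplementation, not the library's actual code) returns a one-hot vector marking the train with the most spikes, breaking ties in favour of the first train:

import numpy as np

def spike_argmax_sketch(trains):
    # hypothetical stand-in for v.spike_argmax: one-hot vector selecting
    # the spike train with the largest spike count (ties go to the first)
    counts = np.array([len(t) for t in trains])
    out = np.zeros(len(trains), dtype=int)
    out[counts.argmax()] = 1
    return out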
Example #2
    def test_spiketrain_write(self):
        block = Block()
        seg = Segment()
        block.segments.append(seg)

        spiketrain = SpikeTrain(times=[3, 4, 5] * pq.s,
                                t_stop=10.0,
                                name="spikes!",
                                description="sssssspikes")
        seg.spiketrains.append(spiketrain)
        self.write_and_compare([block])

        waveforms = self.rquant((3, 5, 10), pq.mV)
        spiketrain = SpikeTrain(times=[1, 1.1, 1.2] * pq.ms,
                                t_stop=1.5 * pq.s,
                                name="spikes with wf",
                                description="spikes for waveform test",
                                waveforms=waveforms)

        seg.spiketrains.append(spiketrain)
        self.write_and_compare([block])

        spiketrain.left_sweep = np.random.random(10) * pq.ms
        self.write_and_compare([block])

        spiketrain.left_sweep = pq.Quantity(-10, "ms")
        self.write_and_compare([block])
Example #3
def test_spike_argmax_zeros():
    train1 = SpikeTrain([] * s, t_stop=10)
    train2 = SpikeTrain([] * s, t_stop=10)
    trains = np.array([train1, train2])
    out = v.spike_argmax(trains)
    assert out.size == len(trains)
    assert out.sum() == 1
Example #4
def test_spike_argmax_randomised():
    train1 = SpikeTrain([3, 4] * s, t_stop=10)
    train2 = SpikeTrain([5, 6] * s, t_stop=10)
    trains = np.array([train1, train2])
    out = v.spike_argmax_randomise(trains)
    assert out.size == len(trains)
    assert out.sum() == 1
Example #5
def test_compute_orientation_tuning():
    from neo.core import SpikeTrain
    import quantities as pq
    from exana.stimulus.tools import (make_orientation_trials,
                                      compute_orientation_tuning)

    trials = [
        SpikeTrain(np.arange(0, 10, 2) * pq.s,
                   t_stop=10 * pq.s,
                   orient=315. * pq.deg),
        SpikeTrain(np.arange(0, 10, 0.5) * pq.s,
                   t_stop=10 * pq.s,
                   orient=np.pi / 3 * pq.rad),
        SpikeTrain(np.arange(0, 10, 1) * pq.s,
                   t_stop=10 * pq.s,
                   orient=0 * pq.deg),
        SpikeTrain(np.arange(0, 10, 0.3) * pq.s,
                   t_stop=10 * pq.s,
                   orient=np.pi / 3 * pq.rad)
    ]
    sorted_orients = np.array(
        [0, (np.pi / 3 * pq.rad).rescale(pq.deg) / pq.deg, 315]) * pq.deg
    rates_e = np.array([1., 2.7, 0.5]) / pq.s

    trials = make_orientation_trials(trials)
    rates, orients = compute_orientation_tuning(trials)
    assert ((rates == rates_e).all())
    assert (rates.units == rates_e.units)
    assert ((orients == sorted_orients).all())
    assert (orients.units == sorted_orients.units)
Example #6
def test_make_orientation_trials():
    from neo.core import SpikeTrain
    from exana.stimulus.tools import (make_orientation_trials,
                                      _convert_string_to_quantity_scalar)

    trials = [
        SpikeTrain(np.arange(0, 10, 2) * pq.s,
                   t_stop=10 * pq.s,
                   orient=315. * pq.deg),
        SpikeTrain(np.arange(0, 10, 0.5) * pq.s,
                   t_stop=10 * pq.s,
                   orient=np.pi / 3 * pq.rad),
        SpikeTrain(np.arange(0, 10, 1) * pq.s,
                   t_stop=10 * pq.s,
                   orient=0 * pq.deg),
        SpikeTrain(np.arange(0, 10, 0.3) * pq.s,
                   t_stop=10 * pq.s,
                   orient=np.pi / 3 * pq.rad)
    ]

    sorted_trials = [[trials[2]], [trials[1], trials[3]], [trials[0]]]
    sorted_orients = [
        0 * pq.deg, (np.pi / 3 * pq.rad).rescale(pq.deg), 315 * pq.deg
    ]
    orient_trials = make_orientation_trials(trials, unit=pq.deg)

    for (key, value), trial, orient in zip(orient_trials.items(),
                                           sorted_trials, sorted_orients):
        key = _convert_string_to_quantity_scalar(key)
        assert (key == orient.magnitude)
        for t, st in zip(value, trial):
            assert ((t == st).all())
            assert (t.t_start == st.t_start)
            assert (t.t_stop == st.t_stop)
            assert (t.annotations["orient"] == orient)
Example #7
def test_spike_argmax_regular():
    train1 = SpikeTrain([2, 20] * s, t_stop=100)
    train2 = SpikeTrain([31, 40, 61] * s, t_stop=100)
    train3 = SpikeTrain([1, 4] * s, t_stop=100)
    train4 = SpikeTrain([1, 47] * s, t_stop=100)
    trains = np.array([train1, train2, train3, train4])
    out = v.spike_argmax(trains)
    assert np.array_equal(out, np.array([0, 1, 0, 0]))
Example #8
 def _extract_spikes(self, data, metadata, channel_index, lazy):
     spiketrain = None
     if lazy:
         if channel_index in data[:, 1]:
             spiketrain = SpikeTrain([], units=pq.ms, t_stop=0.0)
             spiketrain.lazy_shape = None
     else:
         spike_times = self._extract_array(data, channel_index)
         if len(spike_times) > 0:
             spiketrain = SpikeTrain(spike_times, units=pq.ms, t_stop=spike_times.max())
     if spiketrain is not None:
         spiketrain.annotate(label=metadata["label"],
                             channel_index=channel_index,
                             dt=metadata["dt"])
         return spiketrain
Example #9
 def test_annotations(self):
     self.testfilename = self.get_filename_path('nixio_fr_ann.nix')
     with NixIO(filename=self.testfilename, mode='ow') as io:
         annotations = {'my_custom_annotation': 'hello block'}
         bl = Block(**annotations)
         annotations = {'something': 'hello hello000'}
         seg = Segment(**annotations)
          an = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='V',
                            sampling_rate=1 * pq.Hz)
          an.annotations['ansigrandom'] = 'hello chars'
          sp = SpikeTrain([3, 4, 5] * s, t_stop=10.0)
          sp.annotations['railway'] = 'hello train'
          ev = Event(np.arange(0, 30, 10) * pq.Hz,
                     labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
          ev.annotations['venue'] = 'hello event'
          ev2 = Event(np.arange(0, 30, 10) * pq.Hz,
                      labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
         ev2.annotations['evven'] = 'hello ev'
         seg.spiketrains.append(sp)
         seg.events.append(ev)
         seg.events.append(ev2)
         seg.analogsignals.append(an)
         bl.segments.append(seg)
         io.write_block(bl)
         io.close()
     with NixIOfr(filename=self.testfilename) as frio:
         frbl = frio.read_block()
         assert 'my_custom_annotation' in frbl.annotations
         assert 'something' in frbl.segments[0].annotations
         # assert 'ansigrandom' in frbl.segments[0].analogsignals[0].annotations
         assert 'railway' in frbl.segments[0].spiketrains[0].annotations
         assert 'venue' in frbl.segments[0].events[0].annotations
         assert 'evven' in frbl.segments[0].events[1].annotations
     os.remove(self.testfilename)
Example #10
 def _handle_processing_group(self, block):
     # todo: handle other modules than Units
     units_group = self._file.get('processing/Units/UnitTimes')
     segment_map = dict(
         (segment.name, segment) for segment in block.segments)
     for name, group in units_group.items():
         if name == 'unit_list':
             pass  # todo
         else:
             segment_name = group['source'].value
             #desc = group['unit_description'].value  # use this to store Neo Unit id?
             segment = segment_map[segment_name]
             if self._lazy:
                 times = np.array(())
                 lazy_shape = group['times'].shape
             else:
                 times = group['times'].value
             spiketrain = SpikeTrain(
                 times,
                 units=pq.second,
                 t_stop=group['t_stop'].value * pq.second
             )  # todo: this is a custom Neo value, general NWB files will not have this - use segment.t_stop instead in that case?
             if self._lazy:
                 spiketrain.lazy_shape = lazy_shape
             spiketrain.segment = segment
             segment.spiketrains.append(spiketrain)
Example #11
def output_fire_freq(spiketime_dict, isi_binsize, isibins, max_time):
    import elephant
    from neo.core import AnalogSignal, SpikeTrain
    import quantities as q
    ratebins = np.arange(0, np.ceil(max_time), isi_binsize)
    # spike_rate: across the entire time
    spike_rate = {key: np.zeros((len(spike_set), len(ratebins)))
                  for key, spike_set in spiketime_dict.items()}
    spike_rate_vs_time_mean = {}
    spike_rate_vs_time_std = {}
    # spike_freq: segmented into pre, stim, post
    spike_freq = {key: {} for key in spiketime_dict.keys()}
    spike_freq_mean = {key: {} for key in spiketime_dict.keys()}
    spike_freq_std = {key: {} for key in spiketime_dict.keys()}
    for key, spike_set in spiketime_dict.items():  # iterate over stimulation conditions
        for i in range(len(spike_set)):  # iterate over trials
            train = SpikeTrain(spike_set[i] * q.s, t_stop=np.ceil(max_time) * q.s)
            spike_rate[key][i] = elephant.statistics.instantaneous_rate(
                train, isi_binsize * q.s).magnitude[:, 0]  # /len(spike_set)
    # separate the spike_rate into pre, post, and stimulation time frames
    for key, rate_set in spike_rate.items():
        for pre_post, binlist in isibins.items():
            binmin_idx = np.abs(ratebins - binlist[0]).argmin()
            binmax_idx = np.abs(ratebins - (binlist[-1] + isi_binsize)).argmin()
            spike_freq[key][pre_post] = spike_rate[key][:, binmin_idx:binmax_idx]
        spike_rate_vs_time_mean[key] = np.mean(spike_rate[key], axis=0)  # average across trials
        spike_rate_vs_time_std[key] = np.std(spike_rate[key], axis=0)  # std across trials
        for pre_post, binlist in isibins.items():
            spike_freq_mean[key][pre_post] = np.mean(spike_freq[key][pre_post], axis=0)
            spike_freq_std[key][pre_post] = np.std(spike_freq[key][pre_post], axis=0)
    return (spike_freq_mean, spike_freq_std, spike_rate_vs_time_mean,
            spike_rate_vs_time_std, ratebins)
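The heavy lifting above is done by `elephant.statistics.instantaneous_rate`; used in isolation (with made-up spike times), it behaves like this:

import quantities as pq
from neo.core import SpikeTrain
from elephant.statistics import instantaneous_rate

train = SpikeTrain([0.5, 1.2, 3.4, 3.5] * pq.s, t_stop=5 * pq.s)
rate = instantaneous_rate(train, sampling_period=0.1 * pq.s)
print(rate.magnitude[:, 0])  # one smoothed rate sample (in Hz) per 0.1 s bin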
Example #12
    def read_spiketrain(fh, block_id, array_id):
        nix_block = fh.handle.blocks[block_id]
        nix_da = nix_block.data_arrays[array_id]

        params = {
            'times': nix_da[:],  # TODO think about lazy data loading
            'dtype': nix_da.dtype,
            't_start': Reader.Help.read_quantity(nix_da.metadata, 't_start'),
            't_stop': Reader.Help.read_quantity(nix_da.metadata, 't_stop')
        }

        name = Reader.Help.get_obj_neo_name(nix_da)
        if name:
            params['name'] = name

        if 'left_sweep' in nix_da.metadata:
            params['left_sweep'] = Reader.Help.read_quantity(nix_da.metadata, 'left_sweep')

        if len(nix_da.dimensions) > 0:
            s_dim = nix_da.dimensions[0]
            params['sampling_rate'] = s_dim.sampling_interval * getattr(pq, s_dim.unit)

        if nix_da.unit:
            params['units'] = nix_da.unit

        st = SpikeTrain(**params)

        for key, value in Reader.Help.read_attributes(nix_da.metadata, 'spiketrain').items():
            setattr(st, key, value)

        st.annotations = Reader.Help.read_annotations(nix_da.metadata, 'spiketrain')

        return st
Example #13
 def read_segment(
     self,
     lazy=False,
     delimiter="\t",
     t_start=0. * pq.s,
     unit=pq.s,
 ):
     """
     delimiter is the columns delimiter in file  "\t" or one space or two space or "," or ";"
     t_start is the time start of all spiketrain 0 by default. unit is the unit of spike times,
     can be a str or directly a quantity.
     """
     assert not lazy, "Do not support lazy"
     unit = pq.Quantity(1, unit)
     seg = Segment(file_origin=os.path.basename(self.filename))
      f = open(self.filename, "r")
     for i, line in enumerate(f):
         alldata = line[:-1].split(delimiter)
         if alldata[-1] == "":
             alldata = alldata[:-1]
         if alldata[0] == "":
             alldata = alldata[1:]
         spike_times = np.array(alldata).astype("f")
         t_stop = spike_times.max() * unit
         sptr = SpikeTrain(spike_times * unit,
                           t_start=t_start,
                           t_stop=t_stop)
         sptr.annotate(channel_index=i)
         seg.spiketrains.append(sptr)
     f.close()
     seg.create_many_to_one_relationship()
     return seg
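Assuming this method belongs to an ASCII spike-train reader along the lines of neo's `AsciiSpikeTrainIO` (an assumption; the class definition is not shown), a typical call would be:

import quantities as pq
from neo.io import AsciiSpikeTrainIO  # assumed reader class

io = AsciiSpikeTrainIO(filename="spikes.txt")
seg = io.read_segment(delimiter="\t", t_start=0. * pq.s, unit=pq.ms)
for st in seg.spiketrains:
    print(st.annotations["channel_index"], len(st))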
Example #14
 def load(self, time_slice=None, strict_slicing=True):
     """
     Load SpikeTrainProxy args:
         :param time_slice: None or tuple of the time slice expressed with quantities.
                         None is the entire spike train.
         :param strict_slicing: True by default.
             Control if an error is raised or not when one of the time_slice members
             (t_start or t_stop) is outside the real time range of the segment.
     """
     interval = None
     if time_slice:
         interval = (float(t)
                     for t in time_slice)  # convert from quantities
     spike_times = self._units_table.get_unit_spike_times(
         self.id, in_interval=interval)
     return SpikeTrain(
         spike_times * self.units,
         self.t_stop,
         units=self.units,
         # sampling_rate=array(1.) * Hz,
         t_start=self.t_start,
         # waveforms=None,
         # left_sweep=None,
         name=self.name,
         # file_origin=None,
         # description=None,
         # array_annotations=None,
         **self.annotations)
Example #15
    def test__construct_subsegment_by_unit(self):
        nb_seg = 3
        nb_unit = 7
        unit_with_sig = [0, 2, 5]
        signal_types = ['Vm', 'Conductances']
        sig_len = 100

        # recording channel groups
        rcgs = [RecordingChannelGroup(name='Vm',
                                      channel_indexes=unit_with_sig),
                RecordingChannelGroup(name='Conductance',
                                      channel_indexes=unit_with_sig)]

        # Units
        all_unit = []
        for u in range(nb_unit):
            un = Unit(name='Unit #%d' % u, channel_indexes=[u])
            all_unit.append(un)

        bl = Block()
        for s in range(nb_seg):
            seg = Segment(name='Simulation %s' % s)
            for j in range(nb_unit):
                st = SpikeTrain([1, 2, 3], units='ms', t_start=0., t_stop=10)
                st.unit = all_unit[j]
                seg.spiketrains.append(st)

            for t in signal_types:
                anasigarr = AnalogSignalArray(np.zeros((sig_len, len(unit_with_sig))),
                                              units='nA',
                                              sampling_rate=1000. * pq.Hz,
                                              channel_indexes=unit_with_sig)
                seg.analogsignalarrays.append(anasigarr)

        # what you want
        subseg = seg.construct_subsegment_by_unit(all_unit[:4])
Example #16
def train_to_neo(train, framerate=24):
  ''' convert a single spike train to Neo SpikeTrain '''
  times = np.nonzero(train)[0]
  times = times / framerate * pq.s
  t_stop = train.shape[-1] / framerate * pq.s
  spike_train = SpikeTrain(times=times, units=pq.s, t_stop=t_stop)
  return spike_train
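A quick usage sketch for this helper, with a made-up binary raster sampled at 24 frames per second:

import numpy as np

raster = np.zeros(72)        # 3 s of frames at 24 fps
raster[[4, 10, 50]] = 1      # spikes at frames 4, 10 and 50
st = train_to_neo(raster, framerate=24)
print(st.times)              # approx. [0.167 0.417 2.083] s
print(st.t_stop)             # 3.0 s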
Example #17
 def remove_short_ISIs(self, clusterID, ISIThreshold):
      '''
      Removes the earlier spike of any pair whose ISI is less than a given threshold.
      WARNING: Modifies cluster in-place.
      :param clusterID: ID of cluster to be modified
      :param ISIThreshold: spikes followed by an ISI below this threshold are removed
      :return: N/A
      '''
      # TODO: if we have an ISI below threshold, do we have to remove BOTH spikes?
      # If duplicate spike times arise from cluster misalignment, then no;
      # if contamination, then yes (we do not know which spike is `correct`).
     spikeTimes = self.clusters[clusterID].spiketrains[0].times.magnitude
     dt = np.zeros(len(spikeTimes))
     dt[:-1] = np.diff(spikeTimes)
     dt[-1] = 1e6
     duplicateSpikeDelays = spikeTimes[np.where(np.abs(dt) <= ISIThreshold)]
     keepSpikeTimes = spikeTimes[np.where(np.abs(dt) > ISIThreshold)]
     keepAmplitudes = self.clusters[clusterID].templateAmplitudes[np.where(
         np.abs(dt) > ISIThreshold)]
     spikeTrain = SpikeTrain(keepSpikeTimes,
                             units='sec',
                             t_stop=max(spikeTimes),
                             t_start=min(0, min(spikeTimes)))
     self.clusters[clusterID].spiketrains[0] = spikeTrain
     self.clusters[clusterID].templateAmplitudes = keepAmplitudes
     self.clusters[clusterID].mergedDuplicates = True
     self.clusters[clusterID].nrMergedDuplicateSpikes = len(
         duplicateSpikeDelays)
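The filtering core is plain NumPy and can be exercised standalone (illustrative values; the cluster bookkeeping is stripped out):

import numpy as np

spikeTimes = np.array([0.0100, 0.0101, 0.0500, 0.0502, 0.2000])  # seconds
ISIThreshold = 0.001  # 1 ms

dt = np.zeros(len(spikeTimes))
dt[:-1] = np.diff(spikeTimes)
dt[-1] = 1e6  # sentinel so the last spike is always kept
keep = spikeTimes[np.abs(dt) > ISIThreshold]
print(keep)  # [0.0101 0.0502 0.2] -- the earlier spike of each close pair is dropped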
Example #18
 def __init__(self,
              clusterID,
              group,
              spikeTimes,
              waveForm=None,
              shank=None,
              maxChannel=None,
              coordinates=None,
              firingRate=None):
     super(Unit, self).__init__(name=str(clusterID))
     self.clusterID = clusterID
     self.group = group
     if hasattr(spikeTimes, 'times'):
         self.spiketrains.append(spikeTimes)
     else:
         newSpiketrain = SpikeTrain(spikeTimes,
                                    units='sec',
                                    t_stop=max(spikeTimes))
         self.spiketrains.append(newSpiketrain)
     self.waveForm = waveForm
     self.shank = shank
     self.maxChannel = maxChannel
     self.coordinates = coordinates
     self.firingRate = firingRate
     self.mergedDuplicates = False
     self.nrMergedDuplicateSpikes = 0
Example #19
def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
                                       side, ADperiod, respWin, filename):
    '''Get the repetition for a unit in a condition in a src file that has been
    processed by the official matlab function.  See proc_src for details.'''
    damaIndex = damaIndex.astype('int32')
    if len(sweep):
        times = np.array([res[0, 0] for res in sweep['time']])
        shapes = np.concatenate([res.flatten()[np.newaxis][np.newaxis] for res
                                 in sweep['shape']], axis=0)
        trig2 = np.array([res[0, 0] for res in sweep['trig2']])
    else:
        times = np.array([])
        shapes = np.array([[[]]])
        trig2 = np.array([])

    times = pq.Quantity(times, units=pq.ms, dtype=np.float32)
    t_start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
    t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
    trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8)
    waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
    sampling_period = pq.Quantity(ADperiod, units=pq.us)

    train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop,
                       trig2=trig2, dtype=np.float32, timestamp=timeStamp,
                       dama_index=damaIndex, side=side, copy=True,
                       respwin=respWin, waveforms=waveforms,
                       file_origin=filename)
    train.annotations['side'] = side
    train.sampling_period = sampling_period
    return train
Example #20
    def __save_segment(self):
        '''
        Write the segment to the Block if it exists
        '''
        # if this is the beginning of the first condition, then we don't want
        # to save, so exit
        # but set __seg from None to False so we know next time to create a
        # segment even if there are no spikes in the condition
        if self.__seg is None:
            self.__seg = False
            return

        if not self.__seg:
            # create dummy values if there are no SpikeTrains in this condition
            self.__seg = Segment(file_origin=self._filename, **self.__params)
            self.__spiketimes = []

        times = pq.Quantity(self.__spiketimes, dtype=np.float32, units=pq.ms)
        train = SpikeTrain(times,
                           t_start=0 * pq.ms,
                           t_stop=self.__t_stop * pq.ms,
                           file_origin=self._filename)

        self.__seg.spiketrains = [train]
        self.__unit.spiketrains.append(train)
        self._blk.segments.append(self.__seg)

        # set an empty segment
        # from now on, we need to set __seg to False rather than None so
        # that if there is a condition with no SpikeTrains we know
        # to create an empty Segment
        self.__seg = False
Example #21
    def read_spiketrain(self,
                        # the first 2 keyword arguments are imposed by the neo.io API
                        lazy=False,
                        cascade=True,
                        segment_duration=15.,
                        t_start=-1,
                        channel_index=0,
                        ):
        """
        With this IO, a SpikeTrain can be accessed directly by its channel number.
        """
        # There are 2 possible behaviours for a SpikeTrain:
        # holding many Spike instances, or directly holding spike times.
        # We choose the first one here.
        if not HAVE_SCIPY:
            raise SCIPY_ERR

        num_spike_by_spiketrain = 40
        sr = 10000.

        if lazy:
            times = []
        else:
            times = (np.random.rand(num_spike_by_spiketrain) * segment_duration +
                     t_start)

        # create a spiketrain
        spiketr = SpikeTrain(times,
                             t_start=t_start * pq.s,
                             t_stop=(t_start + segment_duration) * pq.s,
                             units=pq.s,
                             name='it is a spiketrain from exampleio',
                             )

        if lazy:
            # we add the attribute lazy_shape with the size if loaded
            spiketr.lazy_shape = (num_spike_by_spiketrain,)

        # our spiketrains also hold the waveforms:

        # 1. generate a fake spike shape (2d array if trodness > 1)
        w1 = -stats.nct.pdf(np.arange(11, 60, 4), 5, 20)[::-1] / 3.
        w2 = stats.nct.pdf(np.arange(11, 60, 2), 5, 20)
        w = np.r_[w1, w2]
        w = -w / max(w)

        if not lazy:
            # in the neo API the waveforms attribute is 3D to allow for tetrodes;
            # in our mono-electrode case, dim 1 has size 1
            waveforms = np.tile(w[np.newaxis, np.newaxis, :],
                                (num_spike_by_spiketrain, 1, 1))
            waveforms *= np.random.randn(*waveforms.shape) / 6 + 1
            spiketr.waveforms = waveforms * pq.mV
            spiketr.sampling_rate = sr * pq.Hz
            spiketr.left_sweep = 1.5 * pq.s

        # attributes outside of neo can be annotated
        spiketr.annotate(channel_index=channel_index)

        return spiketr
Example #22
    def test_multiref_write(self):
        blk = Block("blk1")
        signal = AnalogSignal(name="sig1",
                              signal=[0, 1, 2],
                              units="mV",
                              sampling_period=pq.Quantity(1, "ms"))
        othersignal = IrregularlySampledSignal(name="i1",
                                               signal=[0, 0, 0],
                                               units="mV",
                                               times=[1, 2, 3],
                                               time_units="ms")
        event = Event(name="Evee", times=[0.3, 0.42], units="year")
        epoch = Epoch(name="epoche",
                      times=[0.1, 0.2] * pq.min,
                      durations=[0.5, 0.5] * pq.min)
        st = SpikeTrain(name="the train of spikes",
                        times=[0.1, 0.2, 10.3],
                        t_stop=11,
                        units="us")

        for idx in range(3):
            segname = "seg" + str(idx)
            seg = Segment(segname)
            blk.segments.append(seg)
            seg.analogsignals.append(signal)
            seg.irregularlysampledsignals.append(othersignal)
            seg.events.append(event)
            seg.epochs.append(epoch)
            seg.spiketrains.append(st)

        chidx = ChannelIndex([10, 20, 29])
        seg = blk.segments[0]
        st = SpikeTrain(name="choochoo",
                        times=[10, 11, 80],
                        t_stop=1000,
                        units="s")
        seg.spiketrains.append(st)
        blk.channel_indexes.append(chidx)
        for idx in range(6):
            unit = Unit("unit" + str(idx))
            chidx.units.append(unit)
            unit.spiketrains.append(st)

        self.writer.write_block(blk)
        self.compare_blocks([blk], self.reader.blocks)
Example #23
 def test_read_spiketrain_using_eager(self):
     io = self.io_cls(self.test_file)
     st3 = io.read_spiketrain(lazy=False, channel_index=3)
     self.assertIsInstance(st3, SpikeTrain)
     assert_arrays_equal(st3,
                         SpikeTrain(np.arange(3, 104, dtype=float),
                                    t_start=0*pq.s,
                                    t_stop=104*pq.s,
                                    units=pq.ms))
Example #24
def find_peak(seg, varname):
    for i, sig in enumerate(seg.analogsignals):
        if sig.name in varname:
            print("finding peaks for "+sig.name)
            p = ssp.find_peaks_cwt(sig.flatten(), np.arange(100, 550),
                                   min_length=10)
            peaks = p / sig.sampling_rate + sig.t_start
            peaks = SpikeTrain(times=peaks, t_start=sig.t_start,
                               t_stop=sig.t_stop, name=sig.name)
            seg.spiketrains.append(peaks)
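`scipy.signal.find_peaks_cwt` does the detection work here; standalone, on a synthetic two-bump signal, it behaves like this (widths chosen to roughly match the bump scale):

import numpy as np
import scipy.signal as ssp

t = np.linspace(0, 10, 2000)
sig = np.exp(-((t - 3) ** 2) / 0.1) + np.exp(-((t - 7) ** 2) / 0.1)
peaks = ssp.find_peaks_cwt(sig, np.arange(50, 150))
print(t[peaks])  # approximately [3. 7.]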
Example #25
 def test_read_segment_containing_spiketrains_using_eager_cascade(self):
     io = self.io_cls(self.test_file)
     segment = io.read_segment(lazy=False, cascade=True)
     self.assertIsInstance(segment, Segment)
     self.assertEqual(len(segment.spiketrains), NCELLS)
     st0 = segment.spiketrains[0]
     self.assertIsInstance(st0, SpikeTrain)
     assert_arrays_equal(st0,
                         SpikeTrain(np.arange(0, 101, dtype=float),
                                    t_start=0*pq.s,
                                    t_stop=101*pq.ms,
                                    units=pq.ms))
     st4 = segment.spiketrains[4]
     self.assertIsInstance(st4, SpikeTrain)
     assert_arrays_equal(st4,
                         SpikeTrain(np.arange(4, 105, dtype=float),
                                    t_start=0*pq.s,
                                    t_stop=105*pq.ms,
                                    units=pq.ms))
Example #26
    def load(self, time_slice=None, strict_slicing=True,
                    magnitude_mode='rescaled', load_waveforms=False):
        '''
        *Args*:
            :time_slice: None or tuple of the time slice expressed with quantities.
                            None means the entire signal.
            :strict_slicing: True by default.
                 Controls whether an error is raised when one of the time_slice
                 members (t_start or t_stop) is outside the real time range of the segment.
            :magnitude_mode: 'rescaled' or 'raw'.
            :load_waveforms: bool, whether to load waveforms.
        '''

        t_start, t_stop = consolidate_time_slice(time_slice, self.t_start,
                                                 self.t_stop, strict_slicing)
        _t_start, _t_stop = prepare_time_slice(time_slice)

        spike_timestamps = self._rawio.get_spike_timestamps(
            block_index=self._block_index, seg_index=self._seg_index,
            unit_index=self._unit_index, t_start=_t_start, t_stop=_t_stop)

        if magnitude_mode == 'raw':
            # we would need to modify the neo.rawio interface a bit to also read the
            # spike_timestamps' underlying clock, which is not always the same as the signals'
            raise NotImplementedError
        elif magnitude_mode == 'rescaled':
            dtype = 'float64'
            spike_times = self._rawio.rescale_spike_timestamp(spike_timestamps, dtype=dtype)
            units = 's'

        if load_waveforms:
            assert self.sampling_rate is not None, 'Do not have waveforms'

            raw_wfs = self._rawio.get_spike_raw_waveforms(block_index=self._block_index,
                seg_index=self._seg_index, unit_index=self._unit_index,
                            t_start=_t_start, t_stop=_t_stop)
            if magnitude_mode == 'rescaled':
                float_wfs = self._rawio.rescale_waveforms_to_float(raw_wfs,
                                dtype='float32', unit_index=self._unit_index)
                waveforms = pq.Quantity(float_wfs, units=self._wf_units,
                            dtype='float32', copy=False)
            elif magnitude_mode == 'raw':
                # we could also use CompoundUnit here, but that is overkill,
                # so we use dimensionless units
                waveforms = pq.Quantity(raw_wfs, units='',
                                        dtype=raw_wfs.dtype, copy=False)
        else:
            waveforms = None

        sptr = SpikeTrain(spike_times, t_stop, units=units, dtype=dtype,
                t_start=t_start, copy=False, sampling_rate=self.sampling_rate,
                waveforms=waveforms, left_sweep=self.left_sweep, name=self.name,
                file_origin=self.file_origin, description=self.description, **self.annotations)

        return sptr
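On a proxy object (here `proxy` is hypothetical), a typical call might be:

import quantities as pq

st = proxy.load(time_slice=(1. * pq.s, 5. * pq.s),
                magnitude_mode='rescaled', load_waveforms=False)
print(st.t_start, st.t_stop, len(st))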
Example #27
    def test_no_segment_write(self):
        # Tests storing AnalogSignal, IrregularlySampledSignal, and SpikeTrain
        # objects in the secondary (ChannelIndex) substructure without them
        # being attached to a Segment.
        blk = Block("segmentless block")
        signal = AnalogSignal(name="sig1",
                              signal=[0, 1, 2],
                              units="mV",
                              sampling_period=pq.Quantity(1, "ms"))
        othersignal = IrregularlySampledSignal(name="i1",
                                               signal=[0, 0, 0],
                                               units="mV",
                                               times=[1, 2, 3],
                                               time_units="ms")
        sta = SpikeTrain(name="the train of spikes",
                         times=[0.1, 0.2, 10.3],
                         t_stop=11,
                         units="us")
        stb = SpikeTrain(name="the train of spikes b",
                         times=[1.1, 2.2, 10.1],
                         t_stop=100,
                         units="ms")

        chidx = ChannelIndex([8, 13, 21])
        blk.channel_indexes.append(chidx)
        chidx.analogsignals.append(signal)
        chidx.irregularlysampledsignals.append(othersignal)

        unit = Unit()
        chidx.units.append(unit)
        unit.spiketrains.extend([sta, stb])
        self.writer.write_block(blk)

        self.compare_blocks([blk], self.reader.blocks)

        self.writer.close()
        reader = NixIO(self.filename, "ro")
        blk = reader.read_block(neoname="segmentless block")
        chx = blk.channel_indexes[0]
        self.assertEqual(len(chx.analogsignals), 1)
        self.assertEqual(len(chx.irregularlysampledsignals), 1)
        self.assertEqual(len(chx.units[0].spiketrains), 2)
Example #28
 def _read_spiketrain(self, node, parent):
     attributes = self._get_standard_attributes(node)
     t_start = self._get_quantity(node["t_start"])
     t_stop = self._get_quantity(node["t_stop"])
     # todo: handle sampling_rate, waveforms, left_sweep
     spiketrain = SpikeTrain(self._get_quantity(node["times"]),
                             t_start=t_start, t_stop=t_stop,
                             **attributes)
     spiketrain.segment = parent
     self.object_refs[node.attrs["object_ref"]] = spiketrain
     return spiketrain
Example #29
def spiketrain_from_raw(channel: AnalogSignal, threshold: Quantity) -> SpikeTrain:
    assert channel.signal.shape[1] == 1
    spikes = find_spikes(channel, threshold)
    signal = channel.signal.squeeze()
    times = np.array([channel_index_to_time(channel, start) for start, _ in spikes]) * channel.sampling_period.units
    waveforms = np.array([[signal[start:stop]] for start, stop in spikes]) * channel.units
    name = f"{channel.name}#{threshold}"
    result = SpikeTrain(name=name, times=times, units=channel.units, t_start=channel.t_start, t_stop=channel.t_stop,
                        waveforms=waveforms, sampling_rate=channel.sampling_rate, sort=True)
    result.annotate(from_channel=channel.name, threshold=threshold)
    return result
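`find_spikes` and `channel_index_to_time` are not shown. A minimal threshold-crossing sketch compatible with how `find_spikes` is used above (hypothetical; it returns (start, stop) sample windows) might be:

import numpy as np

def find_spikes_sketch(channel, threshold, window=32):
    # hypothetical stand-in: upward crossings of `threshold`, each paired
    # with a fixed-length waveform window of `window` samples
    sig = np.asarray(channel.signal).squeeze()
    thr = float(threshold.rescale(channel.units).magnitude)
    crossings = np.flatnonzero((sig[:-1] < thr) & (sig[1:] >= thr)) + 1
    return [(i, i + window) for i in crossings if i + window <= len(sig)]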
Example #30
 def _extract_spikes(self, data, metadata, channel_index):
     spiketrain = None
     spike_times = self._extract_array(data, channel_index)
     if len(spike_times) > 0:
         spiketrain = SpikeTrain(spike_times,
                                 units=pq.ms,
                                 t_stop=spike_times.max())
         spiketrain.annotate(label=metadata["label"],
                             channel_index=channel_index,
                             dt=metadata["dt"])
         return spiketrain