def test_spike_argmax_indexed():
    """spike_argmax on two non-overlapping, equally long trains gives [1, 0]."""
    first = SpikeTrain([3, 4] * s, t_stop=10)
    second = SpikeTrain([5, 6] * s, t_stop=10)
    result = v.spike_argmax(np.array([first, second]))
    assert np.array_equal(result, np.array([1, 0]))
def test_spiketrain_write(self):
    """Round-trip SpikeTrains through write_and_compare.

    Covers a plain train, a train carrying waveforms, and that same train
    with ``left_sweep`` set first to an array and then to a scalar quantity.
    """
    blk = Block()
    segment = Segment()
    blk.segments.append(segment)

    plain_train = SpikeTrain(times=[3, 4, 5] * pq.s, t_stop=10.0,
                             name="spikes!", description="sssssspikes")
    segment.spiketrains.append(plain_train)
    self.write_and_compare([blk])

    wf = self.rquant((3, 5, 10), pq.mV)
    wf_train = SpikeTrain(times=[1, 1.1, 1.2] * pq.ms, t_stop=1.5 * pq.s,
                          name="spikes with wf",
                          description="spikes for waveform test",
                          waveforms=wf)
    segment.spiketrains.append(wf_train)
    self.write_and_compare([blk])

    # left_sweep as an array quantity ...
    wf_train.left_sweep = np.random.random(10) * pq.ms
    self.write_and_compare([blk])
    # ... and as a scalar quantity
    wf_train.left_sweep = pq.Quantity(-10, "ms")
    self.write_and_compare([blk])
def test_spike_argmax_zeros():
    """spike_argmax on two empty trains returns one entry per train, summing to 1."""
    empty_trains = np.array([SpikeTrain([] * s, t_stop=10) for _ in range(2)])
    result = v.spike_argmax(empty_trains)
    assert result.size == len(empty_trains)
    assert result.sum() == 1
def test_spike_argmax_randomised():
    """spike_argmax_randomise returns one entry per train, summing to 1."""
    stacked = np.array([
        SpikeTrain([3, 4] * s, t_stop=10),
        SpikeTrain([5, 6] * s, t_stop=10),
    ])
    result = v.spike_argmax_randomise(stacked)
    assert result.size == len(stacked)
    assert result.sum() == 1
def test_compute_orientation_tuning():
    """compute_orientation_tuning yields per-orientation mean rates and orientations.

    Two trials share the pi/3 orientation, so the expected rate for it is the
    mean of their rates (2 Hz and 3.4 Hz -> 2.7 Hz).
    """
    from neo.core import SpikeTrain
    import quantities as pq
    from exana.stimulus.tools import (make_orientation_trials,
                                      compute_orientation_tuning)

    # (spike spacing in seconds, orientation annotation) per trial
    trial_spec = [
        (2, 315. * pq.deg),
        (0.5, np.pi / 3 * pq.rad),
        (1, 0 * pq.deg),
        (0.3, np.pi / 3 * pq.rad),
    ]
    trials = [SpikeTrain(np.arange(0, 10, step) * pq.s, t_stop=10 * pq.s,
                         orient=orientation)
              for step, orientation in trial_spec]

    expected_orients = np.array(
        [0, (np.pi / 3 * pq.rad).rescale(pq.deg) / pq.deg, 315]) * pq.deg
    expected_rates = np.array([1., 2.7, 0.5]) / pq.s

    grouped = make_orientation_trials(trials)
    rates, orients = compute_orientation_tuning(grouped)

    assert (rates == expected_rates).all()
    assert rates.units == expected_rates.units
    assert (orients == expected_orients).all()
    assert orients.units == expected_orients.units
def test_make_orientation_trials():
    """make_orientation_trials groups trials by orientation (0, 60, 315 deg here)."""
    from neo.core import SpikeTrain
    from exana.stimulus.tools import (make_orientation_trials,
                                      _convert_string_to_quantity_scalar)

    # (spike spacing in seconds, orientation annotation) per trial
    trial_spec = [
        (2, 315. * pq.deg),
        (0.5, np.pi / 3 * pq.rad),
        (1, 0 * pq.deg),
        (0.3, np.pi / 3 * pq.rad),
    ]
    trials = [SpikeTrain(np.arange(0, 10, step) * pq.s, t_stop=10 * pq.s,
                         orient=orientation)
              for step, orientation in trial_spec]

    expected_groups = [[trials[2]], [trials[1], trials[3]], [trials[0]]]
    expected_orients = [
        0 * pq.deg,
        (np.pi / 3 * pq.rad).rescale(pq.deg),
        315 * pq.deg,
    ]

    grouped = make_orientation_trials(trials, unit=pq.deg)
    for (raw_key, group), ref_group, ref_orient in zip(grouped.items(),
                                                       expected_groups,
                                                       expected_orients):
        parsed_key = _convert_string_to_quantity_scalar(raw_key)
        assert parsed_key == ref_orient.magnitude
        for got, ref in zip(group, ref_group):
            assert (got == ref).all()
            assert got.t_start == ref.t_start
            assert got.t_stop == ref.t_stop
            assert got.annotations["orient"] == ref_orient
def test_spike_argmax_regular():
    """spike_argmax on four trains of differing lengths returns [0, 1, 0, 0]."""
    train1 = SpikeTrain([2, 20] * s, t_stop=100)
    train2 = SpikeTrain([31, 40, 61] * s, t_stop=100)
    train3 = SpikeTrain([1, 4] * s, t_stop=100)
    train4 = SpikeTrain([1, 47] * s, t_stop=100)
    # The trains have different lengths, so this is a ragged sequence.
    # NumPy >= 1.24 raises ValueError on ragged input unless dtype=object is
    # explicit. Older NumPy silently built this same 1-D object array.
    trains = np.array([train1, train2, train3, train4], dtype=object)
    out = v.spike_argmax(trains)
    assert np.array_equal(out, np.array([0, 1, 0, 0]))
def _extract_spikes(self, data, metadata, channel_index, lazy):
    """Build an annotated SpikeTrain (times in ms) for one channel, or return None.

    In lazy mode an empty placeholder train is returned when ``channel_index``
    occurs in column 1 of ``data`` (presumably the channel-id column — TODO
    confirm). Otherwise the spike times come from ``self._extract_array``.
    None is returned when the channel has no spikes.
    """
    if lazy:
        if channel_index not in data[:, 1]:
            return None
        train = SpikeTrain([], units=pq.ms, t_stop=0.0)
        train.lazy_shape = None
    else:
        times = self._extract_array(data, channel_index)
        if len(times) == 0:
            return None
        # t_stop is taken as the last spike time
        train = SpikeTrain(times, units=pq.ms, t_stop=times.max())

    train.annotate(label=metadata["label"],
                   channel_index=channel_index,
                   dt=metadata["dt"])
    return train
def test_annotations(self):
    """Annotations written by NixIO must be readable back through NixIOfr.

    Writes a Block containing an annotated Segment, AnalogSignal, SpikeTrain
    and two Events. Reads it back with NixIOfr and checks that each annotation
    key survived. The temporary file is always removed afterwards.
    """
    self.testfilename = self.get_filename_path('nixio_fr_ann.nix')
    with NixIO(filename=self.testfilename, mode='ow') as io:
        bl = Block(my_custom_annotation='hello block')
        seg = Segment(something='hello hello000')
        an = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='V',
                          sampling_rate=1 * pq.Hz)
        an.annotations['ansigrandom'] = 'hello chars'
        sp = SpikeTrain([3, 4, 5] * s, t_stop=10.0)
        sp.annotations['railway'] = 'hello train'
        # Event times are points in time, so they need time units.
        # They were previously given in Hz, which is a frequency.
        ev = Event(np.arange(0, 30, 10) * pq.s,
                   labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
        ev.annotations['venue'] = 'hello event'
        ev2 = Event(np.arange(0, 30, 10) * pq.s,
                    labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
        ev2.annotations['evven'] = 'hello ev'
        seg.spiketrains.append(sp)
        seg.events.append(ev)
        seg.events.append(ev2)
        seg.analogsignals.append(an)
        bl.segments.append(seg)
        io.write_block(bl)
        # No explicit io.close(): leaving the with-block closes the file.

    try:
        with NixIOfr(filename=self.testfilename) as frio:
            frbl = frio.read_block()
            assert 'my_custom_annotation' in frbl.annotations
            assert 'something' in frbl.segments[0].annotations
            # TODO: AnalogSignal annotations are not checked (this assertion
            # was disabled in the original); confirm whether NixIOfr keeps them.
            # assert 'ansigrandom' in frbl.segments[0].analogsignals[0].annotations
            assert 'railway' in frbl.segments[0].spiketrains[0].annotations
            assert 'venue' in frbl.segments[0].events[0].annotations
            assert 'evven' in frbl.segments[0].events[1].annotations
    finally:
        # Remove the file even when an assertion above fails.
        os.remove(self.testfilename)
def _handle_processing_group(self, block):
    """Attach the SpikeTrains stored under ``processing/Units/UnitTimes`` to ``block``.

    Every child group except ``unit_list`` is read for three datasets:
    ``source`` (the name of the Segment it belongs to), ``times`` and
    ``t_stop``. The resulting SpikeTrain (in seconds) is appended to the
    matching Segment of ``block``. Nothing happens if the file has no such
    group.
    """
    # todo: handle other modules than Units
    units_group = self._file.get('processing/Units/UnitTimes')
    if units_group is None:
        # .get() returns None for a missing path. Previously this crashed
        # with AttributeError on None.items().
        return
    segment_map = {segment.name: segment for segment in block.segments}
    for name, group in units_group.items():
        if name == 'unit_list':
            continue  # todo
        # ``dataset[()]`` reads the whole dataset in every h5py version.
        # ``Dataset.value`` was removed in h5py 3.0.
        segment_name = group['source'][()]
        if isinstance(segment_name, bytes):
            # h5py >= 3 returns variable-length strings as bytes.
            segment_name = segment_name.decode('utf-8')
        # desc = group['unit_description'][()]  # use this to store Neo Unit id?
        segment = segment_map[segment_name]
        if self._lazy:
            times = np.array(())
            lazy_shape = group['times'].shape
        else:
            times = group['times'][()]
        # todo: this is a custom Neo value, general NWB files will not have
        # this - use segment.t_stop instead in that case?
        spiketrain = SpikeTrain(
            times,
            units=pq.second,
            t_stop=group['t_stop'][()] * pq.second,
        )
        if self._lazy:
            spiketrain.lazy_shape = lazy_shape
        spiketrain.segment = segment
        segment.spiketrains.append(spiketrain)
def output_fire_freq(spiketime_dict, isi_binsize, isibins, max_time):
    """Compute per-trial instantaneous firing rates and summarise them per condition.

    Parameters
    ----------
    spiketime_dict : dict
        Maps a stimulation condition to a sequence of trials. Each trial holds
        spike times that are multiplied by seconds.
    isi_binsize : float
        Bin width in seconds. It is used both for the rate bins and as the
        sampling period passed to ``elephant.statistics.instantaneous_rate``.
    isibins : dict
        Maps a period label (e.g. pre / stim / post) to a sequence of bin
        times. Only the first and last entries are used.
    max_time : float
        Recording end time in seconds. It is rounded up with ``np.ceil``.

    Returns
    -------
    tuple
        ``(spike_freq_mean, spike_freq_std, spike_rate_vs_time_mean,
        spike_rate_vs_time_std, ratebins)``
    """
    import elephant
    from neo.core import SpikeTrain
    import quantities as q

    t_end = np.ceil(max_time)
    ratebins = np.arange(0, t_end, isi_binsize)

    # rate over the entire recording: one row per trial, one column per bin
    spike_rate = {cond: np.zeros((len(trials), len(ratebins)))
                  for cond, trials in spiketime_dict.items()}
    spike_rate_vs_time_mean = {}
    spike_rate_vs_time_std = {}
    # rates segmented into periods (pre, stim, post, ...)
    spike_freq = {cond: {} for cond in spiketime_dict.keys()}
    spike_freq_mean = {cond: {} for cond in spiketime_dict.keys()}
    spike_freq_std = {cond: {} for cond in spiketime_dict.keys()}

    for cond, trials in spiketime_dict.items():
        for row in range(len(trials)):
            train = SpikeTrain(trials[row] * q.s, t_stop=t_end * q.s)
            rate = elephant.statistics.instantaneous_rate(train, isi_binsize * q.s)
            spike_rate[cond][row] = rate.magnitude[:, 0]

    for cond in spike_rate:
        rates = spike_rate[cond]
        # slice the whole-recording rate into each period's bin window
        for period, bin_times in isibins.items():
            lo = np.abs(ratebins - bin_times[0]).argmin()
            hi = np.abs(ratebins - (bin_times[-1] + isi_binsize)).argmin()
            spike_freq[cond][period] = rates[:, lo:hi]
        spike_rate_vs_time_mean[cond] = np.mean(rates, axis=0)  # across trials
        spike_rate_vs_time_std[cond] = np.std(rates, axis=0)    # across trials
        for period in isibins:
            spike_freq_mean[cond][period] = np.mean(spike_freq[cond][period], axis=0)
            spike_freq_std[cond][period] = np.std(spike_freq[cond][period], axis=0)

    return (spike_freq_mean, spike_freq_std, spike_rate_vs_time_mean,
            spike_rate_vs_time_std, ratebins)
def read_spiketrain(fh, block_id, array_id):
    """Rebuild a neo SpikeTrain from a NIX DataArray.

    The array ``fh.handle.blocks[block_id].data_arrays[array_id]`` supplies
    the spike times, dtype, units and (if present) the first dimension. Its
    metadata supplies t_start, t_stop, optional left_sweep, extra attributes
    and annotations.
    """
    nix_da = fh.handle.blocks[block_id].data_arrays[array_id]
    meta = nix_da.metadata

    kwargs = {
        'times': nix_da[:],  # TODO think about lazy data loading
        'dtype': nix_da.dtype,
        't_start': Reader.Help.read_quantity(meta, 't_start'),
        't_stop': Reader.Help.read_quantity(meta, 't_stop'),
    }

    neo_name = Reader.Help.get_obj_neo_name(nix_da)
    if neo_name:
        kwargs['name'] = neo_name

    if 'left_sweep' in meta:
        kwargs['left_sweep'] = Reader.Help.read_quantity(meta, 'left_sweep')

    if len(nix_da.dimensions) > 0:
        first_dim = nix_da.dimensions[0]
        # NOTE(review): a sampling *interval* is stored under 'sampling_rate'
        # here; confirm this matches how the writer encodes it.
        kwargs['sampling_rate'] = (first_dim.sampling_interval
                                   * getattr(pq, first_dim.unit))

    if nix_da.unit:
        kwargs['units'] = nix_da.unit

    spiketrain = SpikeTrain(**kwargs)
    extra_attrs = Reader.Help.read_attributes(meta, 'spiketrain')
    for attr_name, attr_value in extra_attrs.items():
        setattr(spiketrain, attr_name, attr_value)
    spiketrain.annotations = Reader.Help.read_annotations(meta, 'spiketrain')
    return spiketrain
def read_segment(
        self,
        lazy=False,
        delimiter="\t",
        t_start=0. * pq.s,
        unit=pq.s,
):
    """Read the text file into a Segment holding one SpikeTrain per line.

    delimiter is the columns delimiter in file "\t" or one space or two space or "," or ";"
    t_start is the time start of all spiketrain 0 by default.
    unit is the unit of spike times, can be a str or directly a quantity.

    Line ``i`` becomes a SpikeTrain annotated with ``channel_index=i``. Its
    t_stop is the largest spike time on that line.
    """
    assert not lazy, "Do not support lazy"

    unit = pq.Quantity(1, unit)

    seg = Segment(file_origin=os.path.basename(self.filename))

    # The "U" mode flag was removed in Python 3.11 (open() raises ValueError).
    # Plain text mode already performs universal-newline translation.
    # The with-statement closes the file even if parsing fails.
    with open(self.filename, "r") as f:
        for i, line in enumerate(f):
            # Strip only the line terminator. The previous line[:-1] also cut
            # the last character of a final line with no trailing newline.
            alldata = line.rstrip("\r\n").split(delimiter)
            # tolerate one leading and one trailing delimiter
            if alldata[-1] == "":
                alldata = alldata[:-1]
            if alldata[0] == "":
                alldata = alldata[1:]

            spike_times = np.array(alldata).astype("f")
            t_stop = spike_times.max() * unit

            sptr = SpikeTrain(spike_times * unit, t_start=t_start, t_stop=t_stop)
            sptr.annotate(channel_index=i)
            seg.spiketrains.append(sptr)

    seg.create_many_to_one_relationship()
    return seg
def load(self, time_slice=None, strict_slicing=True):
    """
    Load SpikeTrainProxy

    args:
        :param time_slice: None or tuple ``(t_start, t_stop)`` of the time
            slice, expressed with quantities (plain numbers are taken to be in
            ``self.units``). None is the entire spike train. A ``None`` member
            leaves that side of the interval open.
        :param strict_slicing: True by default. Control if an error is raised or
            not when one of the time_slice members (t_start or t_stop) is outside
            the real time range of the segment.
            NOTE(review): this argument is not used by the current implementation.
    """
    interval = None
    if time_slice:
        # Build a concrete (start, stop) tuple. Two problems with the old code:
        # - It passed a one-shot generator, which cannot be indexed.
        # - float() of a quantity dropped its units, so 500 ms became 500.
        # Bounds are expressed in self.units, the same unit the returned
        # spike times are multiplied by below.
        bounds = []
        for t, open_end in zip(time_slice, (float('-inf'), float('inf'))):
            if t is None:
                bounds.append(open_end)
            elif hasattr(t, 'rescale'):
                bounds.append(float(t.rescale(self.units)))
            else:
                bounds.append(float(t))
        interval = tuple(bounds)

    spike_times = self._units_table.get_unit_spike_times(
        self.id, in_interval=interval)
    return SpikeTrain(
        spike_times * self.units,
        self.t_stop,
        units=self.units,
        t_start=self.t_start,
        name=self.name,
        **self.annotations)
def test__construct_subsegment_by_unit(self):
    """construct_subsegment_by_unit keeps only the SpikeTrains of the requested units.

    Each segment gets one SpikeTrain per unit plus one AnalogSignalArray per
    signal type. Asking for the first four units should therefore yield four
    SpikeTrains.
    """
    nb_seg = 3
    nb_unit = 7
    unit_with_sig = [0, 2, 5]
    signal_types = ['Vm', 'Conductances']
    sig_len = 100

    # Unit
    all_unit = [Unit(name='Unit #%d' % u, channel_indexes=[u])
                for u in range(nb_unit)]

    bl = Block()
    for seg_idx in range(nb_seg):
        seg = Segment(name='Simulation %s' % seg_idx)
        bl.segments.append(seg)
        for j in range(nb_unit):
            st = SpikeTrain([1, 2, 3], units='ms', t_start=0., t_stop=10)
            st.unit = all_unit[j]
            # The train was previously built but never attached to the
            # segment, so there was nothing to select by unit.
            seg.spiketrains.append(st)

        for t in signal_types:
            anasigarr = AnalogSignalArray(np.zeros((sig_len, len(unit_with_sig))),
                                          units='nA',
                                          sampling_rate=1000. * pq.Hz,
                                          channel_indexes=unit_with_sig)
            seg.analogsignalarrays.append(anasigarr)

    # what you want
    subseg = seg.construct_subsegment_by_unit(all_unit[:4])
    # one SpikeTrain per selected unit from the last segment
    assert len(subseg.spiketrains) == 4
def train_to_neo(train, framerate=24):
    '''
    Convert a single frame-binned spike train to a Neo SpikeTrain.

    Every nonzero entry along the first axis of ``train`` is a spike. A spike
    at frame ``k`` gets the time ``k / framerate`` seconds. t_stop is the
    number of frames along the last axis divided by ``framerate``.
    '''
    spike_frames = np.nonzero(train)[0]
    duration = train.shape[-1] / framerate * pq.s
    return SpikeTrain(times=spike_frames / framerate * pq.s,
                      units=pq.s,
                      t_stop=duration)
def remove_short_ISIs(self, clusterID, ISIThreshold):
    '''
    Remove spikes that are followed too closely by another spike.

    For every pair of consecutive spikes whose inter-spike interval is
    <= ISIThreshold, the EARLIER spike of the pair is dropped and the later
    one is kept. The last spike is never dropped. (An older docstring said
    the second spike was removed; the code has always dropped the first.)
    Template amplitudes are filtered the same way.
    WARNING: Modifies cluster in-place.

    :param clusterID: ID of cluster to be modified
    :param ISIThreshold: threshold in the same units as the magnitudes of the
        cluster's spike times
    :return: N/A
    '''
    # TODO: if we have an ISI below threshold, do we have to remove BOTH spikes?
    # if duplicate spike times due to cluster misalignment, then no...
    # if contamination then yes (we do not know which spike is `correct`)
    cluster = self.clusters[clusterID]
    spikeTimes = cluster.spiketrains[0].times.magnitude

    if len(spikeTimes) == 0:
        # Nothing to remove. The min()/max() calls below would raise on an
        # empty array.
        cluster.mergedDuplicates = True
        cluster.nrMergedDuplicateSpikes = 0
        return

    # dt[i] = interval from spike i to spike i + 1. The large sentinel makes
    # sure the last spike is always kept.
    dt = np.zeros(len(spikeTimes))
    dt[:-1] = np.diff(spikeTimes)
    dt[-1] = 1e6
    absDt = np.abs(dt)
    keepIdx = np.flatnonzero(absDt > ISIThreshold)
    nrRemoved = int(np.count_nonzero(absDt <= ISIThreshold))

    spikeTrain = SpikeTrain(spikeTimes[keepIdx], units='sec',
                            t_stop=max(spikeTimes),
                            t_start=min(0, min(spikeTimes)))
    cluster.spiketrains[0] = spikeTrain
    cluster.templateAmplitudes = cluster.templateAmplitudes[keepIdx]
    cluster.mergedDuplicates = True
    cluster.nrMergedDuplicateSpikes = nrRemoved
def __init__(self, clusterID, group, spikeTimes, waveForm=None, shank=None,
             maxChannel=None, coordinates=None, firingRate=None):
    """Create a Unit named after ``clusterID`` holding a single spike train.

    ``spikeTimes`` may be an object that already has a ``times`` attribute
    (appended as-is). Otherwise it is treated as raw spike times and wrapped
    in a SpikeTrain in seconds, with t_stop set to the largest spike time.
    """
    super(Unit, self).__init__(name=str(clusterID))
    self.clusterID = clusterID
    self.group = group

    train = spikeTimes
    if not hasattr(train, 'times'):
        train = SpikeTrain(spikeTimes, units='sec', t_stop=max(spikeTimes))
    self.spiketrains.append(train)

    self.waveForm = waveForm
    self.shank = shank
    self.maxChannel = maxChannel
    self.coordinates = coordinates
    self.firingRate = firingRate
    # bookkeeping for remove_short_ISIs-style post-processing
    self.mergedDuplicates = False
    self.nrMergedDuplicateSpikes = 0
def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
                                       side, ADperiod, respWin, filename):
    '''Build the SpikeTrain for one repetition of a unit in one condition.

    Intended for src files that have already been processed by the official
    matlab function. See proc_src for details.'''
    damaIndex = damaIndex.astype('int32')

    if not len(sweep):
        raw_times = np.array([])
        shapes = np.array([[[]]])
        raw_trig2 = np.array([])
    else:
        raw_times = np.array([entry[0, 0] for entry in sweep['time']])
        shapes = np.concatenate(
            [entry.flatten()[np.newaxis][np.newaxis] for entry in sweep['shape']],
            axis=0)
        raw_trig2 = np.array([entry[0, 0] for entry in sweep['trig2']])

    spike_times = pq.Quantity(raw_times, units=pq.ms, dtype=np.float32)
    start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
    stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
    trig2_times = pq.Quantity(raw_trig2, units=pq.ms, dtype=np.uint8)
    waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
    sampling_period = pq.Quantity(ADperiod, units=pq.us)

    train = SpikeTrain(times=spike_times,
                       t_start=start, t_stop=stop,
                       trig2=trig2_times, dtype=np.float32,
                       timestamp=timeStamp,
                       dama_index=damaIndex, side=side, copy=True,
                       respwin=respWin, waveforms=waveforms,
                       file_origin=filename)
    train.annotations['side'] = side
    train.sampling_period = sampling_period
    return train
def __save_segment(self): ''' Write the segment to the Block if it exists ''' # if this is the beginning of the first condition, then we don't want # to save, so exit # but set __seg from None to False so we know next time to create a # segment even if there are no spike in the condition if self.__seg is None: self.__seg = False return if not self.__seg: # create dummy values if there are no SpikeTrains in this condition self.__seg = Segment(file_origin=self._filename, **self.__params) self.__spiketimes = [] times = pq.Quantity(self.__spiketimes, dtype=np.float32, units=pq.ms) train = SpikeTrain(times, t_start=0 * pq.ms, t_stop=self.__t_stop * pq.ms, file_origin=self._filename) self.__seg.spiketrains = [train] self.__unit.spiketrains.append(train) self._blk.segments.append(self.__seg) # set an empty segment # from now on, we need to set __seg to False rather than None so # that if there is a condition with no SpikeTrains we know # to create an empty Segment self.__seg = False
def read_spiketrain(self,
                    # the 2 first key arguments are imposed by neo.io API
                    lazy=False,
                    cascade=True,
                    segment_duration=15.,
                    t_start=-1,
                    channel_index=0,
                    ):
    """
    Return a fake SpikeTrain; with this IO a SpikeTrain can be accessed
    directly by its channel number.

    The train spans [t_start, t_start + segment_duration] seconds and holds
    40 uniformly random spike times. When not lazy, it also gets noisy fake
    waveforms, a 10 kHz sampling rate and a left_sweep. When lazy, the times
    are empty and only ``lazy_shape`` records the size.
    """
    # A SpikeTrain could hold Spike objects or plain spike times; this
    # example stores the spike times directly.
    if not HAVE_SCIPY:
        raise SCIPY_ERR

    n_spikes = 40
    rate_hz = 10000.

    if lazy:
        times = []
    else:
        times = np.random.rand(n_spikes) * segment_duration + t_start

    spiketr = SpikeTrain(times,
                         t_start=t_start * pq.s,
                         t_stop=(t_start + segment_duration) * pq.s,
                         units=pq.s,
                         name='it is a spiketrain from exampleio',
                         )
    if lazy:
        # lazy objects only record the size they would have once loaded
        spiketr.lazy_shape = (n_spikes,)

    # fake spike shape assembled from two scipy non-central t pdfs,
    # then normalised by its maximum and sign-flipped
    rising = -stats.nct.pdf(np.arange(11, 60, 4), 5, 20)[::-1] / 3.
    falling = stats.nct.pdf(np.arange(11, 60, 2), 5, 20)
    template = np.r_[rising, falling]
    template = -template / max(template)

    if not lazy:
        # waveforms are 3-D (spike, channel, sample); a single electrode
        # gives a channel axis of size 1
        waveforms = np.tile(template[np.newaxis, np.newaxis, :],
                            (n_spikes, 1, 1))
        waveforms *= np.random.randn(*waveforms.shape) / 6 + 1
        spiketr.waveforms = waveforms * pq.mV
        spiketr.sampling_rate = rate_hz * pq.Hz
        spiketr.left_sweep = 1.5 * pq.s

    # for attributes out of neo you can annotate
    spiketr.annotate(channel_index=channel_index)
    return spiketr
def test_multiref_write(self):
    """Objects shared by several segments and units must round-trip.

    One AnalogSignal, IrregularlySampledSignal, Event, Epoch and SpikeTrain
    are attached to three segments. A second SpikeTrain is attached to the
    first segment and to six units of one ChannelIndex.
    """
    blk = Block("blk1")
    sig = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
                       sampling_period=pq.Quantity(1, "ms"))
    irr_sig = IrregularlySampledSignal(name="i1", signal=[0, 0, 0],
                                       units="mV", times=[1, 2, 3],
                                       time_units="ms")
    evt = Event(name="Evee", times=[0.3, 0.42], units="year")
    epc = Epoch(name="epoche", times=[0.1, 0.2] * pq.min,
                durations=[0.5, 0.5] * pq.min)
    shared_train = SpikeTrain(name="the train of spikes",
                              times=[0.1, 0.2, 10.3], t_stop=11, units="us")

    for idx in range(3):
        seg = Segment("seg" + str(idx))
        blk.segments.append(seg)
        seg.analogsignals.append(sig)
        seg.irregularlysampledsignals.append(irr_sig)
        seg.events.append(evt)
        seg.epochs.append(epc)
        seg.spiketrains.append(shared_train)

    chidx = ChannelIndex([10, 20, 29])
    first_seg = blk.segments[0]
    unit_train = SpikeTrain(name="choochoo", times=[10, 11, 80],
                            t_stop=1000, units="s")
    first_seg.spiketrains.append(unit_train)
    blk.channel_indexes.append(chidx)
    for idx in range(6):
        unit = Unit("unit" + str(idx))
        chidx.units.append(unit)
        unit.spiketrains.append(unit_train)

    self.writer.write_block(blk)
    self.compare_blocks([blk], self.reader.blocks)
def test_read_spiketrain_using_eager(self):
    """Eagerly read channel 3 and compare it with the expected SpikeTrain."""
    io = self.io_cls(self.test_file)
    train = io.read_spiketrain(lazy=False, channel_index=3)
    self.assertIsInstance(train, SpikeTrain)
    expected = SpikeTrain(np.arange(3, 104, dtype=float),
                          t_start=0 * pq.s,
                          t_stop=104 * pq.s,
                          units=pq.ms)
    assert_arrays_equal(train, expected)
def find_peak(seg, varname):
    """Detect peaks in matching AnalogSignals and store them as SpikeTrains.

    For every signal in ``seg.analogsignals`` whose name is ``in varname``,
    peaks are located with ``ssp.find_peaks_cwt`` (widths 100..549 samples,
    min_length=10). The peak indices are converted to times and appended to
    ``seg.spiketrains`` under the signal's name.
    """
    for sig in seg.analogsignals:
        if sig.name not in varname:
            continue
        print("finding peaks for " + sig.name)
        peak_idx = ssp.find_peaks_cwt(sig.flatten(), np.arange(100, 550),
                                      min_length=10)
        # sample index -> time relative to the signal's start
        peak_times = peak_idx / sig.sampling_rate + sig.t_start
        seg.spiketrains.append(SpikeTrain(times=peak_times,
                                          t_start=sig.t_start,
                                          t_stop=sig.t_stop,
                                          name=sig.name))
def test_read_segment_containing_spiketrains_using_eager_cascade(self):
    """Eager cascading read: NCELLS trains; trains 0 and 4 match expectations."""
    io = self.io_cls(self.test_file)
    segment = io.read_segment(lazy=False, cascade=True)
    self.assertIsInstance(segment, Segment)
    self.assertEqual(len(segment.spiketrains), NCELLS)
    # train k holds spikes k, k+1, ..., stop-1 (ms) with t_stop = stop ms
    for idx, stop in ((0, 101), (4, 105)):
        train = segment.spiketrains[idx]
        self.assertIsInstance(train, SpikeTrain)
        expected = SpikeTrain(np.arange(idx, stop, dtype=float),
                              t_start=0 * pq.s,
                              t_stop=stop * pq.ms,
                              units=pq.ms)
        assert_arrays_equal(train, expected)
def load(self, time_slice=None, strict_slicing=True,
         magnitude_mode='rescaled', load_waveforms=False):
    '''
    Load the spike train (and optionally its waveforms) from the rawio.

    *Args*:
        :time_slice: None or tuple of the time slice expressed with quantities.
            None is the entire signal.
        :strict_slicing: True by default. Control if an error is raised or not
            when one of time_slice member (t_start or t_stop) is outside the real
            time range of the segment.
        :magnitude_mode: 'rescaled' or 'raw'. Only 'rescaled' is implemented;
            'raw' raises NotImplementedError.
        :load_waveforms: bool load waveforms or not.

    *Raises*:
        NotImplementedError: magnitude_mode == 'raw'.
        ValueError: any other unknown magnitude_mode (previously this surfaced
            as a NameError on an undefined ``spike_times``).
    '''
    if magnitude_mode == 'raw':
        # we must modify a bit the neo.rawio interface to also read the spike_timestamps
        # underlying clock which is not always same as sigs
        raise NotImplementedError(
            "magnitude_mode='raw' is not supported for spike trains yet")
    if magnitude_mode != 'rescaled':
        raise ValueError("magnitude_mode must be 'rescaled' or 'raw', "
                         "got {!r}".format(magnitude_mode))

    t_start, t_stop = consolidate_time_slice(time_slice, self.t_start,
                                             self.t_stop, strict_slicing)
    _t_start, _t_stop = prepare_time_slice(time_slice)

    spike_timestamps = self._rawio.get_spike_timestamps(block_index=self._block_index,
                                                        seg_index=self._seg_index,
                                                        unit_index=self._unit_index,
                                                        t_start=_t_start,
                                                        t_stop=_t_stop)

    dtype = 'float64'
    spike_times = self._rawio.rescale_spike_timestamp(spike_timestamps, dtype=dtype)
    units = 's'

    if load_waveforms:
        assert self.sampling_rate is not None, 'Do not have waveforms'

        raw_wfs = self._rawio.get_spike_raw_waveforms(block_index=self._block_index,
                                                      seg_index=self._seg_index,
                                                      unit_index=self._unit_index,
                                                      t_start=_t_start,
                                                      t_stop=_t_stop)
        float_wfs = self._rawio.rescale_waveforms_to_float(raw_wfs,
                                                           dtype='float32',
                                                           unit_index=self._unit_index)
        waveforms = pq.Quantity(float_wfs, units=self._wf_units,
                                dtype='float32', copy=False)
    else:
        waveforms = None

    sptr = SpikeTrain(spike_times, t_stop, units=units, dtype=dtype,
                      t_start=t_start, copy=False, sampling_rate=self.sampling_rate,
                      waveforms=waveforms, left_sweep=self.left_sweep,
                      name=self.name, file_origin=self.file_origin,
                      description=self.description, **self.annotations)

    return sptr
def test_no_segment_write(self):
    """Store signals and SpikeTrains only under a ChannelIndex, with no Segment.

    AnalogSignal, IrregularlySampledSignal and SpikeTrain objects are attached
    only to the secondary ChannelIndex/Unit substructure. The block must
    round-trip, and a fresh read must still expose them through its
    ChannelIndex.
    """
    blk = Block("segmentless block")
    analog = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
                          sampling_period=pq.Quantity(1, "ms"))
    irregular = IrregularlySampledSignal(name="i1", signal=[0, 0, 0],
                                         units="mV", times=[1, 2, 3],
                                         time_units="ms")
    train_a = SpikeTrain(name="the train of spikes", times=[0.1, 0.2, 10.3],
                         t_stop=11, units="us")
    train_b = SpikeTrain(name="the train of spikes b", times=[1.1, 2.2, 10.1],
                         t_stop=100, units="ms")

    chidx = ChannelIndex([8, 13, 21])
    blk.channel_indexes.append(chidx)
    chidx.analogsignals.append(analog)
    chidx.irregularlysampledsignals.append(irregular)
    unit = Unit()
    chidx.units.append(unit)
    unit.spiketrains.extend([train_a, train_b])

    self.writer.write_block(blk)
    self.compare_blocks([blk], self.reader.blocks)
    self.writer.close()

    reader = NixIO(self.filename, "ro")
    reread = reader.read_block(neoname="segmentless block")
    reread_chx = reread.channel_indexes[0]
    self.assertEqual(len(reread_chx.analogsignals), 1)
    self.assertEqual(len(reread_chx.irregularlysampledsignals), 1)
    self.assertEqual(len(reread_chx.units[0].spiketrains), 2)
def _read_spiketrain(self, node, parent):
    """Build a SpikeTrain from ``node``, link it to ``parent`` and register it.

    The train is stored in ``self.object_refs`` under the node's
    ``object_ref`` attribute.
    """
    attributes = self._get_standard_attributes(node)
    start = self._get_quantity(node["t_start"])
    stop = self._get_quantity(node["t_stop"])
    # todo: handle sampling_rate, waveforms, left_sweep
    train = SpikeTrain(self._get_quantity(node["times"]),
                       t_start=start, t_stop=stop,
                       **attributes)
    train.segment = parent
    self.object_refs[node.attrs["object_ref"]] = train
    return train
def spiketrain_from_raw(channel: AnalogSignal, threshold: Quantity) -> SpikeTrain:
    """Detect spikes in a single-channel signal and wrap them in a SpikeTrain.

    ``find_spikes`` supplies (start, stop) sample ranges. Each start is mapped
    to a time via ``channel_index_to_time``, and each [start:stop] slice of the
    signal becomes that spike's waveform, stored with shape
    (n_spikes, 1, n_samples). The result is named ``"<channel name>#<threshold>"``
    and annotated with its source channel and threshold.
    """
    assert channel.signal.shape[1] == 1
    spikes = find_spikes(channel, threshold)
    signal = channel.signal.squeeze()
    time_units = channel.sampling_period.units
    times = np.array([channel_index_to_time(channel, start)
                      for start, _ in spikes]) * time_units
    waveforms = np.array([[signal[start:stop]] for start, stop in spikes]) * channel.units
    name = f"{channel.name}#{threshold}"
    # SpikeTrain units must be a time unit. Passing channel.units (the
    # signal's amplitude unit, e.g. mV) made neo reject the train.
    result = SpikeTrain(name=name, times=times, units=time_units,
                        t_start=channel.t_start, t_stop=channel.t_stop,
                        waveforms=waveforms, sampling_rate=channel.sampling_rate,
                        sort=True)
    result.annotate(from_channel=channel.name, threshold=threshold)
    return result
def _extract_spikes(self, data, metadata, channel_index):
    """Return an annotated SpikeTrain (ms) for ``channel_index``, or None if it has no spikes."""
    times = self._extract_array(data, channel_index)
    if len(times) == 0:
        return None
    # t_stop is taken as the last spike time
    train = SpikeTrain(times, units=pq.ms, t_stop=times.max())
    train.annotate(label=metadata["label"],
                   channel_index=channel_index,
                   dt=metadata["dt"])
    return train