Example #1
0
    def test__children(self):
        """Check the parent bookkeeping of ``self.spike1``.

        The spike is placed in a Segment and in a Unit, and both containers
        run ``create_many_to_one_relationship`` before the checks.
        """
        spike = self.spike1

        seg = Segment(name='seg1')
        seg.spikes = [spike]
        seg.create_many_to_one_relationship()

        owner = Unit(name='unit1')
        owner.spikes = [spike]
        owner.create_many_to_one_relationship()

        parent_classes = ('Segment', 'Unit')
        parent_attrs = ('segment', 'unit')

        # single- and multi-parent declarations
        self.assertEqual(spike._single_parent_objects, parent_classes)
        self.assertEqual(spike._multi_parent_objects, ())

        self.assertEqual(spike._single_parent_containers, parent_attrs)
        self.assertEqual(spike._multi_parent_containers, ())

        # combined parent declarations
        self.assertEqual(spike._parent_objects, parent_classes)
        self.assertEqual(spike._parent_containers, parent_attrs)

        # the actual parents: the segment first, then the unit
        parents = spike.parents
        self.assertEqual(len(parents), 2)
        self.assertEqual(parents[0].name, 'seg1')
        self.assertEqual(parents[1].name, 'unit1')

        assert_neo_object_is_compliant(spike)
Example #2
0
    def setup_units(self):
        """Create ``self.unit1`` and ``self.unit2`` together with their children.

        Each unit receives one element from each of ``self.train1`` /
        ``self.train2`` and ``self.spike1`` / ``self.spike2``; the chosen
        children are also kept on ``self.unit1train``, ``self.unit2train``,
        ``self.unit1spike`` and ``self.unit2spike``.
        """
        extra = {'testarg2': 'yes', 'testarg3': True}

        self.unit1 = first = Unit(name='test', description='tester 1',
                                  file_origin='test.file',
                                  channel_indexes=np.array([1]),
                                  testarg1=1, **extra)
        self.unit2 = second = Unit(name='test', description='tester 2',
                                   file_origin='test.file',
                                   channel_indexes=np.array([2]),
                                   testarg1=1, **extra)

        first.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        second.annotate(testarg11=1.1, testarg10=[1, 2, 3])

        # the two units take opposite elements of each train container
        self.unit1train = [self.train1[0], self.train2[1]]
        self.unit2train = [self.train1[1], self.train2[0]]
        first.spiketrains = self.unit1train
        second.spiketrains = self.unit2train

        # same crossing pattern for the spikes
        self.unit1spike = [self.spike1[0], self.spike2[1]]
        self.unit2spike = [self.spike1[1], self.spike2[0]]
        first.spikes = self.unit1spike
        second.spikes = self.unit2spike

        first.create_many_to_one_relationship()
        second.create_many_to_one_relationship()
def proc_src_units(srcfile, filename):
    '''Collect the units of an src file that has been processed by the
    official matlab function.  See proc_src for details.

    The returned RecordingChannelGroup always starts with an
    'UnassignedSpikes' unit; one further unit is appended per
    (elliptic, boundaries) pair found in the cluster data.
    '''
    group = RecordingChannelGroup(file_origin=filename)
    group.units.append(Unit(name='UnassignedSpikes',
                            file_origin=filename,
                            elliptic=[],
                            boundaries=[],
                            timestamp=[],
                            max_valid=[]))

    timeslice = srcfile['sortInfo'][0, 0]['timeslice'][0, 0]
    max_valid = timeslice['maxValid'][0, 0]
    clusters = timeslice['cluster'][0, 0]
    if not len(clusters):
        # no cluster data: only the unassigned-spikes unit is returned
        return group

    max_valid = max_valid[0, 0]
    ellipses = [item.flatten() for item in clusters['elliptic'].flatten()]
    borders = [item.flatten() for item in clusters['boundaries'].flatten()]
    for ellipse, border in zip(ellipses, borders):
        # NOTE(review): this keyword is spelled 'timeStamp' while the
        # unassigned unit above uses 'timestamp' -- confirm which is intended
        group.units.append(Unit(file_origin=filename,
                                boundaries=[border],
                                elliptic=[ellipse],
                                timeStamp=[],
                                max_valid=[max_valid]))
    return group
Example #4
0
    def setup_units(self):
        """Create ``self.unit1`` and ``self.unit2`` together with their children.

        Each unit receives one element from each of ``self.train1`` /
        ``self.train2`` and ``self.spike1`` / ``self.spike2``; the chosen
        children are also kept on ``self.unit1train``, ``self.unit2train``,
        ``self.unit1spike`` and ``self.unit2spike``.  The relationships are
        set up with the module-level ``create_many_to_one_relationship``.
        """
        extra = {'testarg2': 'yes', 'testarg3': True}

        self.unit1 = first = Unit(name='test', description='tester 1',
                                  file_origin='test.file',
                                  channel_indexes=np.array([1]),
                                  testarg1=1, **extra)
        self.unit2 = second = Unit(name='test', description='tester 2',
                                   file_origin='test.file',
                                   channel_indexes=np.array([2]),
                                   testarg1=1, **extra)

        first.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        second.annotate(testarg11=1.1, testarg10=[1, 2, 3])

        # the two units take opposite elements of each train container
        self.unit1train = [self.train1[0], self.train2[1]]
        self.unit2train = [self.train1[1], self.train2[0]]
        first.spiketrains = self.unit1train
        second.spiketrains = self.unit2train

        # same crossing pattern for the spikes
        self.unit1spike = [self.spike1[0], self.spike2[1]]
        self.unit2spike = [self.spike1[1], self.spike2[0]]
        first.spikes = self.unit1spike
        second.spikes = self.unit2spike

        create_many_to_one_relationship(first)
        create_many_to_one_relationship(second)
Example #5
0
 def _read_unit(self, node, parent):
     """Build a Unit from *node* and attach it to *parent*.

     Spike trains are looked up in ``self.object_refs`` through the
     ``object_ref`` attribute of every child of ``node["spiketrains"]``
     whose name contains "SpikeTrain".  *parent* is stored as the unit's
     ``channel_index``.
     """
     attrs = self._get_standard_attributes(node)
     trains = [
         self.object_refs[child.attrs["object_ref"]]
         for key, child in node["spiketrains"].items()
         if "SpikeTrain" in key
     ]
     unit = Unit(**attrs)
     unit.channel_index = parent
     unit.spiketrains = trains
     return unit
Example #6
0
 def _read_unit(self, node, parent):
     """Build a Unit from *node* and attach it to *parent*.

     Spike trains are looked up in ``self.object_refs`` through the
     ``object_ref`` attribute of every child of ``node["spiketrains"]``
     whose name contains "SpikeTrain".  *parent* is stored as the unit's
     ``channel_index``.
     """
     attrs = self._get_standard_attributes(node)
     trains = [
         self.object_refs[child.attrs["object_ref"]]
         for key, child in node["spiketrains"].items()
         if "SpikeTrain" in key
     ]
     unit = Unit(**attrs)
     unit.channel_index = parent
     unit.spiketrains = trains
     return unit
Example #7
0
    def read_channelindex(self,
                          path,
                          cascade=True,
                          lazy=False,
                          read_waveforms=True):
        """Read the channel group stored at *path* into a ChannelIndex.

        Parameters:
        path: key of the channel group in ``self._exdir_directory``
        cascade, lazy: forwarded unchanged to ``read_analogsignal``,
            ``read_unit`` and ``read_spiketrain``
        read_waveforms: forwarded to ``read_unit`` / ``read_spiketrain``

        Returns the ChannelIndex with its analog signals and units attached.
        """
        channel_group = self._exdir_directory[path]
        group_id = channel_group.attrs['electrode_group_id']
        chx = ChannelIndex(
            name='Channel group {}'.format(group_id),
            index=channel_group.attrs['electrode_idx'],
            channel_ids=channel_group.attrs['electrode_identities'],
            group_id=group_id,
            exdir_path=path)

        # 'LFP' and 'MUA' sub-groups are read the same way, in that order
        for signal_kind in ('LFP', 'MUA'):
            if signal_kind not in channel_group:
                continue
            for signal_group in channel_group[signal_kind].values():
                ana = self.read_analogsignal(signal_group.name,
                                             cascade=cascade,
                                             lazy=lazy)
                chx.analogsignals.append(ana)
                ana.channel_index = chx

        if 'UnitTimes' in channel_group:
            for unit_group in channel_group['UnitTimes'].values():
                unit = self.read_unit(unit_group.name,
                                      cascade=cascade,
                                      lazy=lazy,
                                      read_waveforms=read_waveforms)
                unit.channel_index = chx
                chx.units.append(unit)
                # only the first spike train of the unit is linked back
                sptr = unit.spiketrains[0]
                sptr.channel_index = chx

        elif 'EventWaveform' in channel_group:
            # only consulted when there is no 'UnitTimes' group: wrap the
            # single spike train in a unit carrying its name and annotations
            sptr = self.read_spiketrain(channel_group['EventWaveform'].name,
                                        cascade=cascade,
                                        lazy=lazy,
                                        read_waveforms=read_waveforms)
            unit = Unit(name=sptr.name, **sptr.annotations)
            unit.spiketrains.append(sptr)
            unit.channel_index = chx
            sptr.channel_index = chx
            chx.units.append(unit)
        return chx
Example #8
0
    def test__construct_subsegment_by_unit(self):
        """Smoke test: call ``construct_subsegment_by_unit`` on the last
        segment built below with the first four units."""
        n_segments = 3
        n_units = 7
        channels_with_signal = [0, 2, 5]
        signal_names = ['Vm', 'Conductances']
        n_samples = 100

        # recordingchannelgroups
        rcgs = [
            RecordingChannelGroup(name=label,
                                  channel_indexes=channels_with_signal)
            for label in ('Vm', 'Conductance')
        ]

        # Unit
        all_unit = [Unit(name='Unit #%d' % number, channel_indexes=[number])
                    for number in range(n_units)]

        bl = Block()
        for seg_number in range(n_segments):
            seg = Segment(name='Simulation %s' % seg_number)
            for owner in all_unit:
                train = SpikeTrain([1, 2, 3], units='ms',
                                   t_start=0., t_stop=10)
                train.unit = owner

            for _ in signal_names:
                zeros = np.zeros((n_samples, len(channels_with_signal)))
                seg.analogsignalarrays.append(
                    AnalogSignalArray(zeros,
                                      units='nA',
                                      sampling_rate=1000. * pq.Hz,
                                      channel_indexes=channels_with_signal))

        # what you want
        subseg = seg.construct_subsegment_by_unit(all_unit[:4])
Example #9
0
    def create_all_annotated(cls):
        """Return a Block containing one annotated object of each kind.

        Annotations come from ``cls.rdict`` and the data values from
        ``cls.rquant``; the block and the segment are additionally passed
        through ``cls.populate_dates``.
        """
        def with_annotations(obj, count):
            # annotate obj with cls.rdict(count) and hand it back
            obj.annotate(**cls.rdict(count))
            return obj

        times = cls.rquant(1, pq.s)
        signal = cls.rquant(1, pq.V)

        block = with_annotations(Block(), 3)
        cls.populate_dates(block)

        segment = with_annotations(Segment(), 4)
        cls.populate_dates(segment)
        block.segments.append(segment)

        segment.analogsignals.append(with_annotations(
            AnalogSignal(signal=signal, sampling_rate=pq.Hz), 2))

        segment.irregularlysampledsignals.append(with_annotations(
            IrregularlySampledSignal(times=times,
                                     signal=signal,
                                     time_units=pq.s), 2))

        segment.epochs.append(with_annotations(
            Epoch(times=times, durations=times), 4))

        segment.events.append(with_annotations(Event(times=times), 4))

        # the spike train also carries two Quantity-valued annotations
        train = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
        train_annotations = cls.rdict(6)
        train_annotations["quantity"] = pq.Quantity(10, "mV")
        train_annotations["qarray"] = pq.Quantity(range(10), "mA")
        train.annotate(**train_annotations)
        segment.spiketrains.append(train)

        channel_index = with_annotations(
            ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10]), 5)
        block.channel_indexes.append(channel_index)

        channel_index.units.append(with_annotations(Unit(), 2))

        return block
Example #10
0
    def read_block(self, lazy=False, cascade=True, **kargs):
        '''
        Reads a block from the simple spike data file "fname" generated
        with BrainWare

        lazy: stored on the instance for the duration of the read and reset
            to False afterwards
        cascade: if False, an empty Block carrying only file_origin is
            returned and the file is not opened
        kargs: not supported; passing any raises NotImplementedError

        Returns the Block.  When cascade is True it holds one ChannelIndex
        containing one Unit, plus whatever the repeated __read_id calls add.
        '''

        # no keyword arguments are implemented so far.  If someone tries to
        # pass them they are expecting them to do something or making a
        # mistake, neither of which should pass silently
        if kargs:
            raise NotImplementedError('This method does not have any '
                                      'argument implemented yet')
        self._fsrc = None
        self.__lazy = lazy

        self._blk = Block(file_origin=self._filename)
        block = self._blk

        # if we aren't doing cascade, don't load anything
        if not cascade:
            return block

        # create the objects to store other objects
        # (builtin int replaces the np.int alias, which NumPy deprecated in
        # 1.20 and removed in 1.24)
        chx = ChannelIndex(file_origin=self._filename,
                           index=np.array([], dtype=int))
        self.__unit = Unit(file_origin=self._filename)

        # load objects into their containers
        block.channel_indexes.append(chx)
        chx.units.append(self.__unit)

        # initialize values
        self.__t_stop = None
        self.__params = None
        self.__seg = None
        self.__spiketimes = None

        # open the file
        with open(self._path, 'rb') as self._fsrc:
            res = True
            # while the file is not done keep reading segments
            while res:
                res = self.__read_id()

        block.create_many_to_one_relationship()

        # cleanup attributes
        self._fsrc = None
        self.__lazy = False

        self._blk = None

        self.__t_stop = None
        self.__params = None
        self.__seg = None
        self.__spiketimes = None

        return block
    def setup_unit(self):
        """Create the fixture unit names and Unit objects.

        ``unitnames2`` / ``units2`` end with an entry that reuses the first
        name of the first group; ``unitnames`` / ``units`` hold only the four
        distinctly named entries.
        """
        name11, name12 = 'unit 1 1', 'unit 1 2'
        name21, name22 = 'unit 2 1', 'unit 2 2'

        self.unitnames1 = [name11, name12]
        self.unitnames2 = [name21, name22, name11]
        self.unitnames = [name11, name12, name21, name22]

        def make_unit(label, channel):
            # one Unit bound to a single channel index
            return Unit(name=label, channel_indexes=np.array([channel]))

        unit11 = make_unit(name11, 1)
        unit12 = make_unit(name12, 2)
        unit21 = make_unit(name21, 1)
        unit22 = make_unit(name22, 2)
        # same name and channel as unit11, but a separate object
        unit23 = make_unit(name11, 1)

        self.units1 = [unit11, unit12]
        self.units2 = [unit21, unit22, unit23]
        self.units = [unit11, unit12, unit21, unit22]
Example #12
0
    def create_all_annotated(cls):
        """Return a Block containing one annotated object of each kind.

        Annotations come from ``cls.rdict`` and the data values from
        ``cls.rquant``.
        """
        def with_annotations(obj, count):
            # annotate obj with cls.rdict(count) and hand it back
            obj.annotate(**cls.rdict(count))
            return obj

        times = cls.rquant(1, pq.s)
        signal = cls.rquant(1, pq.V)

        block = with_annotations(Block(), 3)

        segment = with_annotations(Segment(), 4)
        block.segments.append(segment)

        segment.analogsignals.append(with_annotations(
            AnalogSignal(signal=signal, sampling_rate=pq.Hz), 2))

        segment.irregularlysampledsignals.append(with_annotations(
            IrregularlySampledSignal(times=times, signal=signal,
                                     time_units=pq.s), 2))

        segment.epochs.append(with_annotations(
            Epoch(times=times, durations=times), 4))

        segment.events.append(with_annotations(Event(times=times), 4))

        # the spike train also carries two Quantity-valued annotations
        train = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
        train_annotations = cls.rdict(6)
        train_annotations["quantity"] = pq.Quantity(10, "mV")
        train_annotations["qarray"] = pq.Quantity(range(10), "mA")
        train.annotate(**train_annotations)
        segment.spiketrains.append(train)

        channel_index = with_annotations(
            ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10]), 5)
        block.channel_indexes.append(channel_index)

        channel_index.units.append(with_annotations(Unit(), 2))

        return block
Example #13
0
    def read_unit(fh, block_id, rcg_source_id, unit_id):
        """Build a Unit from the nix source ``unit_id`` found under the
        source ``rcg_source_id`` of block ``block_id``.

        Attributes and annotations are read from the source metadata.  The
        unit's ``spiketrains`` is a ProxyList whose loader selects the
        'spiketrain' data arrays of the block that list this source by name.
        """
        def read_spiketrains(nix_file):
            # keep spiketrain arrays that reference this unit's source
            matching = [
                array for array in nix_file.blocks[block_id].data_arrays
                if array.type == 'spiketrain'
                and source_name in [src.name for src in array.sources]
            ]
            return [Reader.read_spiketrain(fh, block_id, array.name)
                    for array in matching]

        nix_block = fh.handle.blocks[block_id]
        nix_source = nix_block.sources[rcg_source_id].sources[unit_id]
        source_name = nix_source.name

        unit = Unit(source_name)

        attributes = Reader.Help.read_attributes(nix_source.metadata, 'unit')
        for attr_name, attr_value in attributes.items():
            setattr(unit, attr_name, attr_value)

        unit.annotations = Reader.Help.read_annotations(nix_source.metadata,
                                                        'unit')
        unit.spiketrains = ProxyList(fh, read_spiketrains)

        return unit
Example #14
0
    def test__children(self):
        """Check the child and parent bookkeeping of ``self.spike1``.

        The spike is placed in a Segment and in a Unit, and both containers
        run ``create_many_to_one_relationship`` before the checks.
        """
        spike = self.spike1

        seg = Segment(name="seg1")
        seg.spikes = [spike]
        seg.create_many_to_one_relationship()

        owner = Unit(name="unit1")
        owner.spikes = [spike]
        owner.create_many_to_one_relationship()

        parent_classes = ("Segment", "Unit")
        parent_attrs = ("segment", "unit")

        # (attribute, expected value), checked in this order
        expectations = [
            ("_container_child_objects", ()),
            ("_data_child_objects", ()),
            ("_single_parent_objects", parent_classes),
            ("_multi_child_objects", ()),
            ("_multi_parent_objects", ()),
            ("_child_properties", ()),
            ("_single_child_objects", ()),
            ("_container_child_containers", ()),
            ("_data_child_containers", ()),
            ("_single_child_containers", ()),
            ("_single_parent_containers", parent_attrs),
            ("_multi_child_containers", ()),
            ("_multi_parent_containers", ()),
            ("_child_objects", ()),
            ("_child_containers", ()),
            ("_parent_objects", parent_classes),
            ("_parent_containers", parent_attrs),
            ("children", ()),
        ]
        for attr_name, expected in expectations:
            self.assertEqual(getattr(spike, attr_name), expected)

        # the actual parents: the segment first, then the unit
        parents = spike.parents
        self.assertEqual(len(parents), 2)
        self.assertEqual(parents[0].name, "seg1")
        self.assertEqual(parents[1].name, "unit1")

        # the relationship helpers are called on the spike itself before the
        # final compliance check
        spike.create_many_to_one_relationship()
        spike.create_many_to_many_relationship()
        spike.create_relationship()
        assert_neo_object_is_compliant(spike)
Example #15
0
 def read_unit(self,
               path,
               cascade=True,
               lazy=False,
               cluster_num=None,
               read_waveforms=True):
     """Read the group at *path* into a Unit holding one spike train.

     The group's parent must be named 'UnitTimes' (checked with an
     assert).  The Unit is built from the group's attributes plus an
     ``exdir_path`` entry; the remaining arguments are passed positionally
     to ``read_spiketrain``.
     """
     group = self._exdir_directory[path]
     assert group.parent.object_name == 'UnitTimes'

     # group attributes override 'exdir_path' if they define the same key
     unit_kwargs = dict(exdir_path=path)
     unit_kwargs.update(group.attrs.to_dict())
     unit = Unit(**unit_kwargs)

     unit.spiketrains.append(
         self.read_spiketrain(path, cascade, lazy, cluster_num,
                              read_waveforms))
     return unit
Example #16
0
    def test_anonymous_objects_write(self):
        """Write blocks whose segments, data objects and units are created
        without a name, then compare them with what the reader returns."""
        n_blocks = 2
        n_segments = 2
        n_analog = 4
        n_irregular = 2
        n_epochs = 3
        n_events = 4
        n_trains = 3
        n_chx = 5
        n_units = 10

        times = self.rquant(1, pq.s)
        signal = self.rquant(1, pq.V)

        blocks = []
        for _ in range(n_blocks):
            blk = Block()
            blocks.append(blk)

            # primary structure: segments filled with data objects
            for _ in range(n_segments):
                seg = Segment()
                blk.segments.append(seg)
                for _ in range(n_analog):
                    analog = AnalogSignal(signal=signal, sampling_rate=pq.Hz)
                    seg.analogsignals.append(analog)
                for _ in range(n_irregular):
                    irregular = IrregularlySampledSignal(times=times,
                                                         signal=signal,
                                                         time_units=pq.s)
                    seg.irregularlysampledsignals.append(irregular)
                for _ in range(n_epochs):
                    seg.epochs.append(Epoch(times=times, durations=times))
                for _ in range(n_events):
                    seg.events.append(Event(times=times))
                for _ in range(n_trains):
                    train = SpikeTrain(times=times,
                                       t_stop=times[-1]+pq.s,
                                       units=pq.s)
                    seg.spiketrains.append(train)

            # secondary structure: channel indexes (named) with units
            for chx_number in range(n_chx):
                chx = ChannelIndex(name="chx{}".format(chx_number),
                                   index=[1, 2],
                                   channel_ids=[11, 22])
                blk.channel_indexes.append(chx)
                for _ in range(n_units):
                    chx.units.append(Unit())

        self.writer.write_all_blocks(blocks)
        self.compare_blocks(blocks, self.reader.blocks)
Example #17
0
    def test__construct_subsegment_by_unit(self):
        """Build a block with channel indexes, check compliance of the
        pieces, then call ``construct_subsegment_by_unit`` on the last
        segment created and check the result is compliant too."""
        n_segments = 3
        n_units = 7
        channels_with_signal = np.array([0, 2, 5])
        signal_names = ['Vm', 'Conductances']
        n_samples = 100

        # channelindexes
        chxs = [ChannelIndex(name=label, index=channels_with_signal)
                for label in ('Vm', 'Conductance')]

        # Unit
        all_unit = []
        for number in range(n_units):
            new_unit = Unit(name='Unit #%d' % number,
                            channel_indexes=np.array([number]))
            assert_neo_object_is_compliant(new_unit)
            all_unit.append(new_unit)

        blk = Block()
        blk.channel_indexes = chxs
        for seg_number in range(n_segments):
            seg = Segment(name='Simulation %s' % seg_number)
            for owner in all_unit:
                train = SpikeTrain([1, 2], units='ms',
                                   t_start=0., t_stop=10)
                train.unit = owner

            for _ in signal_names:
                zeros = np.zeros((n_samples, len(channels_with_signal)))
                seg.analogsignals.append(
                    AnalogSignal(zeros,
                                 units='nA',
                                 sampling_rate=1000. * pq.Hz,
                                 channel_indexes=channels_with_signal))

        blk.create_many_to_one_relationship()
        for checked in all_unit + chxs + [blk]:
            assert_neo_object_is_compliant(checked)

        # what you want
        newseg = seg.construct_subsegment_by_unit(all_unit[:4])
        assert_neo_object_is_compliant(newseg)
Example #18
0
    def test_multiref_write(self):
        """Write a block in which the same data objects are referenced from
        several containers, then compare it with what the reader returns."""
        blk = Block("blk1")

        # one instance of each data object, shared by all three segments
        shared_analog = AnalogSignal(name="sig1",
                                     signal=[0, 1, 2],
                                     units="mV",
                                     sampling_period=pq.Quantity(1, "ms"))
        shared_irregular = IrregularlySampledSignal(name="i1",
                                                    signal=[0, 0, 0],
                                                    units="mV",
                                                    times=[1, 2, 3],
                                                    time_units="ms")
        shared_event = Event(name="Evee", times=[0.3, 0.42], units="year")
        shared_epoch = Epoch(name="epoche",
                             times=[0.1, 0.2] * pq.min,
                             durations=[0.5, 0.5] * pq.min)
        shared_train = SpikeTrain(name="the train of spikes",
                                  times=[0.1, 0.2, 10.3],
                                  t_stop=11,
                                  units="us")

        for number in range(3):
            seg = Segment("seg{}".format(number))
            blk.segments.append(seg)
            seg.analogsignals.append(shared_analog)
            seg.irregularlysampledsignals.append(shared_irregular)
            seg.events.append(shared_event)
            seg.epochs.append(shared_epoch)
            seg.spiketrains.append(shared_train)

        chidx = ChannelIndex([10, 20, 29])
        # a second train lives in the first segment and in all six units
        unit_train = SpikeTrain(name="choochoo",
                                times=[10, 11, 80],
                                t_stop=1000,
                                units="s")
        blk.segments[0].spiketrains.append(unit_train)
        blk.channel_indexes.append(chidx)
        for number in range(6):
            unit = Unit("unit{}".format(number))
            chidx.units.append(unit)
            unit.spiketrains.append(unit_train)

        self.writer.write_block(blk)
        self.compare_blocks([blk], self.reader.blocks)
Example #19
0
    def read_block(self, lazy=False, cascade=True, load_waveforms=False):
        """
        Read the file into a Block with one Segment.

        With cascade=False only the empty Block is returned.  Otherwise the
        segment from read_segment is attached, recording channels are
        populated, and every spike train gets its own Unit inside a
        per-channel RecordingChannelGroup named 'Group <channel>'.
        """
        # Create block
        block = Block(file_origin=self.filename)
        if not cascade:
            return block

        seg = self.read_segment(self.filename,
                                lazy=lazy,
                                cascade=cascade,
                                load_waveforms=load_waveforms)
        block.segments.append(seg)
        neo.io.tools.populate_RecordingChannel(block,
                                               remove_from_annotation=False)

        # This create rc and RCG for attaching Units
        main_group = block.recordingchannelgroups[0]

        for train in seg.spiketrains:
            chan = train.annotations['channel_index']

            # first channel of the main group with a matching index, if any
            channel = next((candidate
                            for candidate in main_group.recordingchannels
                            if candidate.index == chan), None)
            if channel is None:
                channel = RecordingChannel(index=chan)
                main_group.recordingchannels.append(channel)
                channel.recordingchannelgroups.append(main_group)

            if len(channel.recordingchannelgroups) != 1:
                # otherwise reuse the group found at index 1
                unit_group = channel.recordingchannelgroups[1]
            else:
                # exactly one group so far: add a dedicated one for units
                unit_group = RecordingChannelGroup(
                    name='Group {}'.format(chan))
                unit_group.recordingchannels.append(channel)
                channel.recordingchannelgroups.append(unit_group)
                block.recordingchannelgroups.append(unit_group)

            unit = Unit(name=train.name)
            unit_group.units.append(unit)
            unit.spiketrains.append(train)
        block.create_many_to_one_relationship()

        return block
Example #20
0
    def test_no_segment_write(self):
        """Store signals and spike trains that hang only off a ChannelIndex."""
        # Tests storing AnalogSignal, IrregularlySampledSignal, and SpikeTrain
        # objects in the secondary (ChannelIndex) substructure without them
        # being attached to a Segment.
        block_name = "segmentless block"
        blk = Block(block_name)

        analog = AnalogSignal(name="sig1",
                              signal=[0, 1, 2],
                              units="mV",
                              sampling_period=pq.Quantity(1, "ms"))
        irregular = IrregularlySampledSignal(name="i1",
                                             signal=[0, 0, 0],
                                             units="mV",
                                             times=[1, 2, 3],
                                             time_units="ms")
        train_a = SpikeTrain(name="the train of spikes",
                             times=[0.1, 0.2, 10.3],
                             t_stop=11,
                             units="us")
        train_b = SpikeTrain(name="the train of spikes b",
                             times=[1.1, 2.2, 10.1],
                             t_stop=100,
                             units="ms")

        chidx = ChannelIndex([8, 13, 21])
        blk.channel_indexes.append(chidx)
        chidx.analogsignals.append(analog)
        chidx.irregularlysampledsignals.append(irregular)

        unit = Unit()
        chidx.units.append(unit)
        unit.spiketrains.extend([train_a, train_b])
        self.writer.write_block(blk)

        self.compare_blocks([blk], self.reader.blocks)

        # close the writer, then read the block back by name through a new
        # read-only NixIO on the same file
        self.writer.close()
        reader = NixIO(self.filename, "ro")
        read_back = reader.read_block(neoname=block_name)
        read_chx = read_back.channel_indexes[0]
        self.assertEqual(len(read_chx.analogsignals), 1)
        self.assertEqual(len(read_chx.irregularlysampledsignals), 1)
        self.assertEqual(len(read_chx.units[0].spiketrains), 2)
Example #21
0
    def test__issue_285(self):
        """Round-trip a block through PickleIO and check that the spike
        train's unit and the epoch come back with the right types.

        The pickle file is written to "blk.pkl" in the current directory and
        removed again when the test finishes, whether it passes or fails.
        """
        import os

        def remove_pickle():
            # the file may not exist if writing failed
            try:
                os.remove("blk.pkl")
            except OSError:
                pass

        self.addCleanup(remove_pickle)

        train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0)
        unit = Unit()
        train.unit = unit
        unit.spiketrains.append(train)

        epoch = Epoch([0, 10, 20], [2, 2, 2], ["a", "b", "c"], units="ms")

        blk = Block()
        seg = Segment()
        seg.spiketrains.append(train)
        seg.epochs.append(epoch)
        epoch.segment = seg
        blk.segments.append(seg)

        # write with one PickleIO instance ...
        writer = PickleIO(filename="blk.pkl")
        writer.write(blk)

        # ... and read back with a fresh one
        reader = PickleIO(filename="blk.pkl")
        r_blk = reader.read_block()
        r_seg = r_blk.segments[0]
        self.assertIsInstance(r_seg.spiketrains[0].unit, Unit)
        self.assertIsInstance(r_seg.epochs[0], Epoch)
Example #22
0
    def test_multiref_write(self):
        """Write a block in which one signal is shared by three segments and
        one spike train by six units, then compare with the read-back."""
        blk = Block("blk1")
        shared_analog = AnalogSignal(name="sig1", signal=[0, 1, 2],
                                     units="mV",
                                     sampling_period=pq.Quantity(1, "ms"))

        for number in range(3):
            seg = Segment("seg{}".format(number))
            blk.segments.append(seg)
            seg.analogsignals.append(shared_analog)

        chidx = ChannelIndex([10, 20, 29])
        # the train lives in the first segment and in every unit
        unit_train = SpikeTrain(name="choochoo", times=[10, 11, 80],
                                t_stop=1000, units="s")
        blk.segments[0].spiketrains.append(unit_train)
        blk.channel_indexes.append(chidx)
        for number in range(6):
            unit = Unit("unit{}".format(number))
            chidx.units.append(unit)
            unit.spiketrains.append(unit_train)

        self.writer.write_block(blk)
        self.compare_blocks([blk], self.reader.blocks)
Example #23
0
    def read_block(self,
                   lazy=False,
                   cascade=True,
                   get_waveforms=True,
                   cluster_metadata='all',
                   raw_data_units='uV',
                   get_raw_data=False,
                   ):
        """
        Reads a block with segments and channel_indexes

        Parameters:
        lazy: bool, default = False
            Passed on to read_spiketrain and read_analogsignal
        cascade: bool, default = True
            If False no segment or channel index is created and an empty
            block is returned
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        get_raw_data: bool, default = False
            Whether or not to get the raw traces
        raw_data_units: str, default = "uV"
            SI units of the raw trace according to voltage_gain given to klusta
        cluster_metadata: str, default = "all"
            Which clusters to load, possibilities are "noise", "unsorted",
            "good", "all", if all is selected noise is omitted.
        """
        assert isinstance(cluster_metadata, str)
        blk = Block()
        if cascade:
            seg = Segment(file_origin=self.filename)
            blk.segments += [seg]
            # stays None when self.models is empty, so the duration lookup
            # after the loop cannot hit an unbound name
            model = None
            for model in self.models:
                group_id = model.channel_group
                group_meta = {'group_id': group_id}
                group_meta.update(model.metadata)
                chx = ChannelIndex(name='channel group #{}'.format(group_id),
                                   index=model.channels,
                                   **group_meta)
                blk.channel_indexes.append(chx)
                for cluster_id in model.cluster_ids:
                    meta = model.cluster_metadata[cluster_id]
                    if cluster_metadata == 'all':
                        # "all" means everything except noise clusters
                        if meta == 'noise':
                            continue
                    elif cluster_metadata != meta:
                        continue
                    sptr = self.read_spiketrain(cluster_id=cluster_id,
                                                model=model, lazy=lazy,
                                                cascade=cascade,
                                                get_waveforms=get_waveforms)
                    sptr.annotations.update({'cluster_metadata': meta,
                                             'group_id': model.channel_group})
                    sptr.channel_index = chx
                    # one unit per loaded cluster
                    unit = Unit()
                    unit.spiketrains.append(sptr)
                    chx.units.append(unit)
                    unit.channel_index = chx
                    seg.spiketrains.append(sptr)
                if get_raw_data:
                    ana = self.read_analogsignal(model, raw_data_units,
                                                 lazy, cascade)
                    ana.channel_index = chx
                    seg.analogsignals.append(ana)

            # the segment duration is taken from the last model iterated
            if model is not None:
                seg.duration = model.duration * pq.s

        blk.create_many_to_one_relationship()
        return blk
def proc_f32(filename):
    '''Load an f32 file that has already been processed by the official matlab
    file converter.  That matlab data is saved to an m-file, which is then
    converted to a numpy '.npz' file.  This numpy file is the file actually
    loaded.  This function converts it to a neo block and returns the block.
    This block can be compared to the block produced by BrainwareF32IO to
    make sure BrainwareF32IO is working properly

    block = proc_f32(filename)

    filename: The file name of the numpy file to load.  It should end with
    '*_f32_py?.npz'. This will be converted to a neo 'file_origin' property
    with the value '*.f32', so the filename to compare should fit that pattern.
    'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
    for the python 3 version of the numpy file.

    example: filename = 'file1_f32_py2.npz'
             f32 file name = 'file1.f32'

    Returns a Block with one ChannelIndex holding one Unit, and one Segment
    (containing a single SpikeTrain) per spike train found in the file.  If
    numpy reports that the file cannot be interpreted "as a pickle", the
    block is returned without any segments.
    '''

    # Strip the 12-character '_f32_py?.npz' suffix to rebuild the .f32 name.
    filenameorig = os.path.basename(filename[:-12] + '.f32')

    # create the objects to store other objects
    block = Block(file_origin=filenameorig)
    chx = ChannelIndex(file_origin=filenameorig,
                       # the ``np.int`` alias was removed from NumPy; the
                       # builtin ``int`` is what it always referred to
                       index=np.array([], dtype=int),
                       channel_names=np.array([], dtype='S'))
    unit = Unit(file_origin=filenameorig)

    # load objects into their containers
    block.channel_indexes.append(chx)
    chx.units.append(unit)

    try:
        with np.load(filename) as f32obj:
            # Take the first stored array.  Indexing through ``files`` works
            # whether ``items()`` returns a list or a (non-indexable) view.
            f32file = f32obj[f32obj.files[0]].flatten()
    except IOError as exc:
        # ``exc.message`` does not exist on Python 3; ``str(exc)`` carries
        # the same text on both major versions.
        if 'as a pickle' in str(exc):
            block.create_many_to_one_relationship()
            return block
        else:
            raise

    sweeplengths = [res[0, 0].tolist() for res in f32file['sweeplength']]
    stims = [res.flatten().tolist() for res in f32file['stim']]

    sweeps = [res['spikes'].flatten() for res in f32file['sweep'] if res.size]

    for sweeplength, stim, sweep in zip(sweeplengths, stims, sweeps):
        # The stimulus parameters depend only on the sweep, so build them
        # once per sweep instead of once per spike train.
        paramnames = ['Param%s' % i for i in range(len(stim))]
        params = dict(zip(paramnames, stim))

        for trainpts in sweep:
            if trainpts.size:
                trainpts = trainpts.flatten().astype('float32')
            else:
                trainpts = []

            train = SpikeTrain(trainpts,
                               units=pq.ms,
                               t_start=0,
                               t_stop=sweeplength,
                               file_origin=filenameorig)

            segment = Segment(file_origin=filenameorig, **params)
            segment.spiketrains = [train]
            unit.spiketrains.append(train)
            block.segments.append(segment)

    block.create_many_to_one_relationship()

    return block
Example #25
0
    def setUp(self):
        """Build the in-memory Block fixture used by the tests.

        Removes any existing directory at ``self.fname``, then creates a
        Block with one Segment, two ChannelIndex objects, three SpikeTrains
        (each wrapped in its own Unit), one AnalogSignal and one Epoch.  The
        resulting block is stored in ``self.blk``.
        """
        self.fname = '/tmp/test.exdir'
        if os.path.exists(self.fname):
            shutil.rmtree(self.fname)
        self.n_channels = 5
        self.n_samples = 20
        self.n_spikes = 50

        blk = Block()
        seg = Segment()
        blk.segments.append(seg)
        chx1 = ChannelIndex(index=np.arange(self.n_channels),
                            channel_ids=np.arange(self.n_channels))
        chx2 = ChannelIndex(index=np.arange(self.n_channels),
                            channel_ids=np.arange(self.n_channels) * 2)
        blk.channel_indexes.extend([chx1, chx2])

        waveform_shape = (self.n_spikes, self.n_channels, self.n_samples)

        def add_spiketrain(number, chx):
            """Create spike train/unit ``number`` on ``chx``; return t_stop."""
            times = np.sort(np.random.random(self.n_spikes))
            t_stop = np.ceil(times[-1])
            sptr = SpikeTrain(
                times=times,
                units='s',
                waveforms=np.random.random(waveform_shape) * pq.V,
                name='spikes %d' % number,
                description='sptr%d' % number,
                t_stop=t_stop,
                **{'id': number})
            sptr.channel_index = chx
            unit = Unit(name='unit %d' % number)
            unit.spiketrains.append(sptr)
            chx.units.append(unit)
            seg.spiketrains.append(sptr)
            return t_stop

        # Spike train 1 lives on the first channel index, 2 and 3 on the
        # second one.
        t_stops = [add_spiketrain(1, chx1),
                   add_spiketrain(2, chx2),
                   add_spiketrain(3, chx2)]
        t_stop = max(t_stops) * pq.s

        # The sampling rate is chosen so that the signal spans [0, t_stop].
        ana = AnalogSignal(np.random.random(self.n_samples),
                           sampling_rate=self.n_samples / t_stop,
                           units='V',
                           name='ana1',
                           description='LFP')
        assert t_stop == ana.t_stop
        seg.analogsignals.append(ana)
        epo = Epoch(np.random.random(self.n_samples),
                    durations=[1] * self.n_samples * pq.s,
                    units='s',
                    name='epo1')
        seg.epochs.append(epo)
        self.blk = blk
Example #26
0
    def read_block(
        self,
        lazy=False,
        get_waveforms=True,
        cluster_group=None,
        raw_data_units='uV',
        get_raw_data=False,
    ):
        """
        Reads a block with segments and channel_indexes

        One Segment is created for the whole file and one ChannelIndex per
        entry of ``self.models``.  Every selected cluster becomes a
        SpikeTrain wrapped in its own Unit.

        Parameters:
        lazy: bool, default = False
            Lazy loading is not supported; must be False.
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        get_raw_data: bool, default = False
            Whether or not to get the raw traces
        raw_data_units: str, default = "uV"
            SI units of the raw trace according to voltage_gain given to klusta
        cluster_group: str, default = None
            Which clusters to load, possibilities are "noise", "unsorted",
            "good", if None all is loaded.

        Returns:
        The populated Block.  The segment duration is taken from the last
        model; it is left unset when there are no models.
        """
        assert not lazy, 'Do not support lazy'

        blk = Block()
        seg = Segment(file_origin=self.filename)
        blk.segments += [seg]

        # Stays None when ``self.models`` is empty, so the duration lookup
        # after the loop cannot hit an unbound name.
        model = None
        for model in self.models:
            group_id = model.channel_group
            group_meta = {'group_id': group_id}
            group_meta.update(model.metadata)
            chx = ChannelIndex(name='channel group #{}'.format(group_id),
                               index=model.channels,
                               **group_meta)
            blk.channel_indexes.append(chx)
            for cluster_id in model.cluster_ids:
                meta = model.cluster_metadata[cluster_id]
                # Skip clusters outside the requested group (None = all).
                if cluster_group is not None and cluster_group != meta:
                    continue
                sptr = self.read_spiketrain(cluster_id=cluster_id,
                                            model=model,
                                            get_waveforms=get_waveforms,
                                            raw_data_units=raw_data_units)
                sptr.annotations.update({
                    'cluster_group': meta,
                    'group_id': model.channel_group
                })
                sptr.channel_index = chx
                unit = Unit(cluster_group=meta,
                            group_id=model.channel_group,
                            name='unit #{}'.format(cluster_id))
                unit.spiketrains.append(sptr)
                chx.units.append(unit)
                unit.channel_index = chx
                seg.spiketrains.append(sptr)
            if get_raw_data:
                ana = self.read_analogsignal(model, units=raw_data_units)
                ana.channel_index = chx
                seg.analogsignals.append(ana)

        if model is not None:
            seg.duration = model.duration * pq.s

        blk.create_many_to_one_relationship()
        return blk
Example #27
0
 def _source_unit_to_neo(self, nix_unit):
     """Build a neo Unit from ``nix_unit`` and register it in the object map.

     The Unit is constructed from the attributes produced by
     ``self._nix_attr_to_neo`` and stored in ``self._object_map`` under
     ``nix_unit.id`` before being returned.
     """
     unit = Unit(**self._nix_attr_to_neo(nix_unit))
     self._object_map[nix_unit.id] = unit
     return unit
Example #28
0
    def read_spiketrain(self):
        """Read one SpikeTrain from every channel-group file.

        Channel-group files are those matching the base filename followed by
        ``.<digits...>``; the numeric suffix is used as the channel group
        index.  For each file the header is parsed, the spike records are
        read, the waveforms are sign-flipped and scaled, and a SpikeTrain is
        built with the header parameters passed on as keyword arguments.

        Each spike train is linked to the ChannelIndex registered for its
        channel group and wrapped in a new Unit appended to that
        ChannelIndex.

        Returns:
            list of SpikeTrain, ordered by sorted file name.
        """
        # TODO add parameter to allow user to read raw data or not?
        assert (SpikeTrain in self.readable_objects)

        spike_trains = []

        channel_group_files = glob.glob(
            os.path.join(self._path, self._base_filename) + ".[0-9]*")
        for raw_filename in sorted(channel_group_files):
            with open(raw_filename, "rb") as f:
                params = parse_header_and_leave_cursor(f)

                channel_group_index = int(raw_filename.split(".")[-1])
                bytes_per_timestamp = params.get("bytes_per_timestamp", 4)
                bytes_per_sample = params.get("bytes_per_sample", 1)
                num_spikes = params.get("num_spikes", 0)
                num_chans = params.get("num_chans", 1)
                samples_per_spike = params.get("samples_per_spike", 50)
                timebase = int(
                    params.get("timebase", "96000 hz").split(" ")[0]) * pq.Hz
                sampling_rate = params.get("rawrate", 48000) * pq.Hz

                timestamp_dtype = ">u" + str(bytes_per_timestamp)
                waveform_dtype = "<i" + str(bytes_per_sample)

                # One record = a scalar timestamp followed by the samples of
                # one channel.  The scalar field and the explicit shape tuple
                # avoid the deprecated ``(type, 1)`` / shape-``1`` spellings,
                # whose meaning NumPy has announced will change.
                dtype = np.dtype([("times", timestamp_dtype),
                                  ("waveforms", waveform_dtype,
                                   (samples_per_spike,))])

                data = np.fromfile(f,
                                   dtype=dtype,
                                   count=num_spikes * num_chans)
                assert_end_of_data(f)

            # times are saved for each channel
            times = data["times"][::num_chans] / timebase
            assert len(times) == num_spikes
            waveforms = data["waveforms"]
            waveforms = np.reshape(waveforms,
                                   (num_spikes, num_chans, samples_per_spike))
            # TODO HACK !!!! find out if recording is sig - ref or the other
            # way around, this determines the way of the peak which should be
            # possible to set in a parameter e.g. peak='negative'/'positive'
            waveforms = -waveforms.astype(float)

            # Per-channel gain, broadcast over spikes and samples.
            channel_gain_matrix = np.ones(waveforms.shape)
            for i in range(num_chans):
                channel_gain_matrix[:, i, :] *= self._channel_gain(
                    channel_group_index, i)
            waveforms = scale_analog_signal(waveforms, channel_gain_matrix,
                                            self._adc_fullscale,
                                            bytes_per_sample)

            # TODO get left_sweep from setfile?
            spike_train = SpikeTrain(times,
                                     t_stop=self._duration,
                                     waveforms=waveforms * pq.uV,
                                     sampling_rate=sampling_rate,
                                     left_sweep=0.2 * pq.ms,
                                     **params)
            spike_trains.append(spike_train)
            channel_index = self._channel_group_to_channel_index[
                channel_group_index]
            spike_train.channel_index = channel_index
            # TODO unit can have several spiketrains from different segments,
            # not necessarily relevant here though
            unit = Unit()
            unit.spiketrains.append(spike_train)
            channel_index.units.append(unit)

        return spike_trains
Example #29
0
    def read_block(self, lazy=False, cascade=True):
        """Return a Block holding the spike trains of every cluster.

        Segment boundaries cannot be inferred from raw spike times, so all
        spike trains are placed in a single segment.  Supporting several
        segments would require the caller to supply the boundaries and this
        code to sort the spikes into them.

        An empty Block is returned when no 'fet' files are found or when
        ``cascade`` is false.  With ``lazy`` set, spike trains are created
        without times and carry the spike count in ``lazy_shape``.
        """
        block = Block()

        # Look up the KlustaKwik feature and cluster files.
        self._fetfiles = self._fp.read_filenames('fet')
        self._clufiles = self._fp.read_filenames('clu')
        if len(self._fetfiles) == 0 or not cascade:
            return block

        # A single segment receives every spike train.
        seg = Segment(name='seg0', index=0, file_origin=self.filename)
        block.segments.append(seg)

        self.spiketrains = {}
        for group in sorted(self._fetfiles.keys()):
            spks, features = self._load_spike_times(self._fetfiles[group])

            # Without a cluster file the data is unclustered: every spike is
            # assigned to cluster 0.
            if group in self._clufiles:
                uids = self._load_unit_id(self._clufiles[group])
            else:
                uids = np.zeros(spks.shape, dtype=np.int32)

            if len(spks) != len(uids):
                raise ValueError("lengths of fet and clu files are different")

            # One Unit and one SpikeTrain per distinct cluster id.
            for unit_id in sorted(np.unique(uids)):
                label = 'unit %d from group %d' % (unit_id, group)
                u = Unit(name=label, index=unit_id, group=group)

                unit_spikes = spks[uids == unit_id]
                times = [] if lazy else unit_spikes / self.sampling_rate
                st = SpikeTrain(times=times,
                                units='sec',
                                t_start=0.0,
                                t_stop=spks.max() / self.sampling_rate,
                                name=label)
                if lazy:
                    st.lazy_shape = len(unit_spikes)

                st.annotations['cluster'] = unit_id
                st.annotations['group'] = group

                # The features are only attached when the data was loaded.
                if not lazy and len(features) != 0:
                    st.annotations['waveform_features'] = features

                u.spiketrains.append(st)
                seg.spiketrains.append(st)

        block.create_many_to_one_relationship()
        return block
Example #30
0
    def read_block(self,
                   block_index=0,
                   lazy=False,
                   signal_group_mode=None,
                   units_group_mode=None,
                   load_waveforms=False):
        """
        Read one Block together with its ChannelIndex, Unit and Segment
        objects.

        :param block_index: int default 0. In case of several blocks,
            block_index can be specified.

        :param lazy: False by default.

        :param signal_group_mode: 'split-all' or 'group-by-same-units'
            (default depends on the IO).
            This controls how channels are grouped in AnalogSignal:
            * 'split-all': each channel will give an AnalogSignal
            * 'group-by-same-units': all channels sharing the same quantity
              units are grouped in a 2D AnalogSignal

        :param units_group_mode: 'split-all' or 'all-in-one' (default depends
            on the IO).
            This controls how Unit objects are grouped in ChannelIndex:
            * 'split-all': each neo.Unit is assigned to a new neo.ChannelIndex
            * 'all-in-one': all neo.Unit are grouped in the same
              neo.ChannelIndex (global spike sorting for instance)

        :param load_waveforms: False by default. Controls whether
            SpikeTrain.waveforms is None or not.

        :return: the populated Block.
        """

        if signal_group_mode is None:
            signal_group_mode = self._prefered_signal_group_mode
            if self._prefered_signal_group_mode == 'split-all':
                self.logger.warning(
                    "the default signal_group_mode will change from "
                    "'split-all' to 'group-by-same-units' in next release")

        if units_group_mode is None:
            units_group_mode = self._prefered_units_group_mode

        # annotations (copied so that popping does not alter raw_annotations)
        bl_annotations = dict(self.raw_annotations['blocks'][block_index])
        bl_annotations.pop('segments')
        bl_annotations = check_annotations(bl_annotations)

        bl = Block(**bl_annotations)

        # ChannelIndex are split in 2 parts:
        #  * some for AnalogSignals
        #  * some for Units

        # ChannelIndex for AnalogSignals
        all_channels = self.header['signal_channels']
        channel_indexes_list = self.get_group_channel_indexes()
        for channel_index in channel_indexes_list:
            for i, (ind_within,
                    ind_abs) in self._make_signal_channel_subgroups(
                        channel_index,
                        signal_group_mode=signal_group_mode).items():
                if signal_group_mode == "split-all":
                    # Work on a copy: 'name' is removed below and the shared
                    # raw annotations must not lose it as a side effect.
                    chidx_annotations = dict(
                        self.raw_annotations['signal_channels'][i])
                elif signal_group_mode == "group-by-same-units":
                    # this should be done with array_annotation soon:
                    keys = list(self.raw_annotations['signal_channels'][
                        ind_abs[0]].keys())
                    # take key from first channel of the group
                    chidx_annotations = {key: [] for key in keys}
                    for j in ind_abs:
                        for key in keys:
                            v = self.raw_annotations['signal_channels'][j].get(
                                key, None)
                            chidx_annotations[key].append(v)
                # 'name' is passed explicitly to ChannelIndex below.
                chidx_annotations.pop('name', None)
                chidx_annotations = check_annotations(chidx_annotations)
                # this should be done with array_annotation soon:
                ch_names = all_channels[ind_abs]['name'].astype('S')
                neo_channel_index = ChannelIndex(
                    index=ind_within,
                    channel_names=ch_names,
                    channel_ids=all_channels[ind_abs]['id'],
                    name='Channel group {}'.format(i),
                    **chidx_annotations)

                bl.channel_indexes.append(neo_channel_index)

        # ChannelIndex and Unit
        # 2 cases are possible in neo; different IOs have chosen one or other:
        #  * All units are grouped in the same ChannelIndex and indexes are
        #    all channels: 'all-in-one'
        #  * Each unit is assigned to one ChannelIndex: 'split-all'
        # This is kept for compatibility
        unit_channels = self.header['unit_channels']
        if units_group_mode == 'all-in-one':
            if unit_channels.size > 0:
                channel_index = ChannelIndex(index=np.array([], dtype='i'),
                                             name='ChannelIndex for all Unit')
                bl.channel_indexes.append(channel_index)
            for c in range(unit_channels.size):
                unit_annotations = self.raw_annotations['unit_channels'][c]
                unit_annotations = check_annotations(unit_annotations)
                unit = Unit(**unit_annotations)
                channel_index.units.append(unit)

        elif units_group_mode == 'split-all':
            for c in range(len(unit_channels)):
                unit_annotations = self.raw_annotations['unit_channels'][c]
                unit_annotations = check_annotations(unit_annotations)
                unit = Unit(**unit_annotations)
                channel_index = ChannelIndex(index=np.array([], dtype='i'),
                                             name='ChannelIndex for Unit')
                channel_index.units.append(unit)
                bl.channel_indexes.append(channel_index)

        # Read all segments
        for seg_index in range(self.segment_count(block_index)):
            seg = self.read_segment(block_index=block_index,
                                    seg_index=seg_index,
                                    lazy=lazy,
                                    signal_group_mode=signal_group_mode,
                                    load_waveforms=load_waveforms)
            bl.segments.append(seg)

        # create link to other containers ChannelIndex and Units
        for seg in bl.segments:
            for c, anasig in enumerate(seg.analogsignals):
                bl.channel_indexes[c].analogsignals.append(anasig)

            # The unit ChannelIndex objects were appended after the signal
            # ones, so they start at offset ``nsig``.
            nsig = len(seg.analogsignals)
            for c, sptr in enumerate(seg.spiketrains):
                if units_group_mode == 'all-in-one':
                    bl.channel_indexes[nsig].units[c].spiketrains.append(sptr)
                elif units_group_mode == 'split-all':
                    bl.channel_indexes[nsig +
                                       c].units[0].spiketrains.append(sptr)

        bl.create_many_to_one_relationship()

        return bl
Example #31
0
class TestSegment(unittest.TestCase):
    def setUp(self):
        """Run every ``setup_*`` fixture builder in dependency order.

        The data-object fixtures are built first; units and segments come
        last because they are filled with the objects created before them.
        """
        fixture_order = (
            'analogsignals',
            'analogsignalarrays',
            'epochs',
            'epocharrays',
            'events',
            'eventarrays',
            'irregularlysampledsignals',
            'spikes',
            'spiketrains',
            'units',
            'segments',
        )
        for fixture in fixture_order:
            getattr(self, 'setup_' + fixture)()

    def setup_segments(self):
        """Create ``segment1`` and ``segment2`` and fill them with fixtures.

        Both segments share name, file origin and constructor annotations and
        differ in their description and in the annotations added afterwards.
        Each child container of segment N receives the matching ``*N``
        fixture list, then the many-to-one links are created.
        """
        common = dict(name='test', file_origin='test.file',
                      testarg1=1, testarg2='yes', testarg3=True)
        self.segment1 = Segment(description='tester 1', **common)
        self.segment2 = Segment(description='tester 2', **common)
        self.segment1.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        self.segment2.annotate(testarg11=1.1, testarg10=[1, 2, 3])

        # (segment container, prefix of the fixture attribute on the test)
        children = (
            ('analogsignals', 'sig'),
            ('analogsignalarrays', 'sigarr'),
            ('epochs', 'epoch'),
            ('epocharrays', 'epocharr'),
            ('events', 'event'),
            ('eventarrays', 'eventarr'),
            ('irregularlysampledsignals', 'irsig'),
            ('spikes', 'spike'),
            ('spiketrains', 'train'),
        )
        for container, prefix in children:
            setattr(self.segment1, container, getattr(self, prefix + '1'))
            setattr(self.segment2, container, getattr(self, prefix + '2'))

        for segment in (self.segment1, self.segment2):
            create_many_to_one_relationship(segment)

    def setup_units(self):
        """Create ``unit1`` and ``unit2`` with crossed spike fixtures.

        Unit 1 takes the first train/spike of segment 1 and the second of
        segment 2; unit 2 takes the remaining ones.  The many-to-one links
        are created at the end.
        """
        common = dict(name='test', file_origin='test.file',
                      testarg1=1, testarg2='yes', testarg3=True)
        self.unit1 = Unit(description='tester 1',
                          channel_indexes=np.array([1]), **common)
        self.unit2 = Unit(description='tester 2',
                          channel_indexes=np.array([2]), **common)
        self.unit1.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        self.unit2.annotate(testarg11=1.1, testarg10=[1, 2, 3])

        trains_a, trains_b = self.train1, self.train2
        self.unit1train = [trains_a[0], trains_b[1]]
        self.unit2train = [trains_a[1], trains_b[0]]
        self.unit1.spiketrains = self.unit1train
        self.unit2.spiketrains = self.unit2train

        spikes_a, spikes_b = self.spike1, self.spike2
        self.unit1spike = [spikes_a[0], spikes_b[1]]
        self.unit2spike = [spikes_a[1], spikes_b[0]]
        self.unit1.spikes = self.unit1spike
        self.unit2.spikes = self.unit2spike

        for unit in (self.unit1, self.unit2):
            create_many_to_one_relationship(unit)

    def setup_analogsignals(self):
        """Create four AnalogSignal fixtures, two per segment.

        Segment-1 signals are in mV, segment-2 signals in V; within each
        segment the signals use channel index 1 and 2.  The signals and their
        names are stored per segment, all together, and (signals only) per
        channel.
        """
        names = ['analogsignal 1 1', 'analogsignal 1 2',
                 'analogsignal 2 1', 'analogsignal 2 2']
        values = [np.arange(0, 10) * pq.mV,
                  np.arange(10, 20) * pq.mV,
                  np.arange(20, 30) * pq.V,
                  np.arange(30, 40) * pq.V]
        channels = [1, 2, 1, 2]

        self.signames1 = names[:2]
        self.signames2 = names[2:]
        self.signames = list(names)

        signals = [AnalogSignal(data, name=name,
                                channel_index=channel,
                                sampling_rate=1*pq.Hz)
                   for data, name, channel in zip(values, names, channels)]

        self.sig1 = signals[:2]
        self.sig2 = signals[2:]
        self.sig = signals

        # Same channel across both segments.
        self.chan1sig = [self.sig1[0], self.sig2[0]]
        self.chan2sig = [self.sig1[1], self.sig2[1]]

    def setup_analogsignalarrays(self):
        """Create the AnalogSignalArray fixtures and their per-channel views.

        Segment 1 gets two arrays and segment 2 three (the third reuses the
        data and name of the first segment-1 array).  ``self.sigarr`` holds a
        variant whose first entry is the first array stacked with itself.
        """
        name11, name12, name21, name22 = [
            'analogsignalarray %d %d' % (seg, num)
            for seg in (1, 2) for num in (1, 2)]

        data11 = np.arange(0, 10).reshape(5, 2) * pq.mV
        data12 = np.arange(10, 20).reshape(5, 2) * pq.mV
        data21 = np.arange(20, 30).reshape(5, 2) * pq.V
        data22 = np.arange(30, 40).reshape(5, 2) * pq.V
        data112 = np.hstack([data11, data11]) * pq.mV

        self.sigarrnames1 = [name11, name12]
        self.sigarrnames2 = [name21, name22, name11]
        self.sigarrnames = [name11, name12, name21, name22]

        def make(data, name, channels):
            # All fixtures share the 1 Hz sampling rate.
            return AnalogSignalArray(data, name=name,
                                     sampling_rate=1*pq.Hz,
                                     channel_index=np.array(channels))

        sigarr11 = make(data11, name11, [1, 2])
        sigarr12 = make(data12, name12, [2, 1])
        sigarr21 = make(data21, name21, [1, 2])
        sigarr22 = make(data22, name22, [2, 1])
        sigarr23 = make(data11, name11, [1, 2])
        sigarr112 = make(data112, name11, [1, 2])

        self.sigarr1 = [sigarr11, sigarr12]
        self.sigarr2 = [sigarr21, sigarr22, sigarr23]
        self.sigarr = [sigarr112, sigarr12, sigarr21, sigarr22]

        # Single-column slices grouped by channel; the column picked depends
        # on the channel order of each array.
        self.chan1sigarr1 = [sigarr11[:, 0:1], sigarr12[:, 1:2]]
        self.chan2sigarr1 = [sigarr11[:, 1:2], sigarr12[:, 0:1]]
        self.chan1sigarr2 = [sigarr21[:, 0:1], sigarr22[:, 1:2],
                             sigarr23[:, 0:1]]
        # NOTE(review): the last entry uses column 0 of sigarr23, the same
        # column as in chan1sigarr2 -- confirm whether 1:2 was intended.
        self.chan2sigarr2 = [sigarr21[:, 1:2], sigarr22[:, 0:1],
                             sigarr23[:, 0:1]]

    def setup_epochs(self):
        """Create four Epoch fixtures, two per segment.

        Each epoch uses its name as label as well.  Only the two segment-1
        epochs carry a ``testattr`` keyword (True and False respectively).
        """
        names = ['epoch 1 1', 'epoch 1 2', 'epoch 2 1', 'epoch 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]
        durations = [11 * pq.s, 21 * pq.s, 31 * pq.ms, 41 * pq.ms]
        channels = [1, 2, 1, 2]
        extras = [{'testattr': True}, {'testattr': False}, {}, {}]

        self.epochnames1 = names[:2]
        self.epochnames2 = names[2:]
        self.epochnames = list(names)

        epochs = [Epoch(time, duration,
                        label=name, name=name, channel_index=channel,
                        **extra)
                  for time, duration, name, channel, extra
                  in zip(times, durations, names, channels, extras)]

        self.epoch1 = epochs[:2]
        self.epoch2 = epochs[2:]
        self.epoch = epochs

    def setup_epocharrays(self):
        """Create four EpochArray fixtures, two per segment.

        Every array holds ten epochs.  Segment-1 arrays have times in ms and
        durations in s; segment-2 arrays use the opposite units.
        """
        names = ['epocharr 1 1', 'epocharr 1 2',
                 'epocharr 2 1', 'epocharr 2 2']
        times = [np.arange(start, start + 10) * unit
                 for start, unit in ((0, pq.ms), (10, pq.ms),
                                     (20, pq.s), (30, pq.s))]
        durations = [np.arange(start, start + 10) * unit
                     for start, unit in ((1, pq.s), (11, pq.s),
                                         (21, pq.ms), (31, pq.ms))]

        self.epocharrnames1 = names[:2]
        self.epocharrnames2 = names[2:]
        self.epocharrnames = list(names)

        arrays = [EpochArray(time, duration, label=name, name=name)
                  for time, duration, name in zip(times, durations, names)]

        self.epocharr1 = arrays[:2]
        self.epocharr2 = arrays[2:]
        self.epocharr = arrays

    def setup_events(self):
        """Create the four Event fixtures used by the segment tests.

        Stores the names on ``self.eventnames*`` and the objects on
        ``self.event*``. Only the first two events receive an extra
        ``testattr`` keyword (``True`` and ``5``).
        """
        names = ['event 1 1', 'event 1 2', 'event 2 1', 'event 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]
        # Extra keyword arguments per event; empty for the second pair.
        extras = [{'testattr': True}, {'testattr': 5}, {}, {}]

        self.eventnames = list(names)
        self.eventnames1 = names[:2]
        self.eventnames2 = names[2:]

        events = [Event(when, label=name, name=name, **extra)
                  for when, name, extra in zip(times, names, extras)]

        self.event = events
        self.event1 = events[:2]
        self.event2 = events[2:]

    def setup_eventarrays(self):
        """Create the four EventArray fixtures used by the segment tests.

        Stores the names on ``self.eventarrnames*`` and the objects on
        ``self.eventarr*`` (first pair, second pair, and all four).
        """
        names = ['eventarr 1 1', 'eventarr 1 2',
                 'eventarr 2 1', 'eventarr 2 2']
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.eventarrnames = list(names)
        self.eventarrnames1 = names[:2]
        self.eventarrnames2 = names[2:]

        arrays = [EventArray(when, label=name, name=name)
                  for when, name in zip(times, names)]

        self.eventarr = arrays
        self.eventarr1 = arrays[:2]
        self.eventarr2 = arrays[2:]

    def setup_irregularlysampledsignals(self):
        """Create the four IrregularlySampledSignal fixtures.

        Stores the names on ``self.irsignames*`` and the objects on
        ``self.irsig*`` (first pair, second pair, and all four).
        """
        names = ['irregularsignal 1 1', 'irregularsignal 1 2',
                 'irregularsignal 2 1', 'irregularsignal 2 2']

        # First pair: mA samples at ms times; second pair: A samples at s.
        samples = [np.arange(0, 10) * pq.mA,
                   np.arange(10, 20) * pq.mA,
                   np.arange(20, 30) * pq.A,
                   np.arange(30, 40) * pq.A]
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.irsignames = list(names)
        self.irsignames1 = names[:2]
        self.irsignames2 = names[2:]

        signals = [IrregularlySampledSignal(when, data, name=name)
                   for when, data, name in zip(times, samples, names)]

        self.irsig = signals
        self.irsig1 = signals[:2]
        self.irsig2 = signals[2:]

    def setup_spikes(self):
        """Create the four Spike fixtures used by the segment tests.

        Stores the names on ``self.spikenames*`` and the objects on
        ``self.spike*`` (first pair, second pair, and all four).
        """
        names = ['spike 1 1', 'spike 1 2', 'spike 2 1', 'spike 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]

        self.spikenames = list(names)
        self.spikenames1 = names[:2]
        self.spikenames2 = names[2:]

        # Every spike is given the same ``t_stop`` keyword of 100 s.
        spikes = [Spike(when, t_stop=100*pq.s, name=name)
                  for when, name in zip(times, names)]

        self.spike = spikes
        self.spike1 = spikes[:2]
        self.spike2 = spikes[2:]

    def setup_spiketrains(self):
        """Create the four SpikeTrain fixtures used by the segment tests.

        Stores the names on ``self.trainnames*`` and the objects on
        ``self.train*`` (first pair, second pair, and all four).
        """
        names = ['spiketrain 1 1', 'spiketrain 1 2',
                 'spiketrain 2 1', 'spiketrain 2 2']
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.trainnames = list(names)
        self.trainnames1 = names[:2]
        self.trainnames2 = names[2:]

        # Every train is given the same ``t_stop`` keyword of 100 s.
        trains = [SpikeTrain(when, t_stop=100*pq.s, name=name)
                  for when, name in zip(times, names)]

        self.train = trains
        self.train1 = trains[:2]
        self.train2 = trains[2:]

    def test_init(self):
        """A Segment keeps the name and index it was built with and has
        no file origin by default.
        """
        segment = Segment(index=3, name='a segment')
        assert_neo_object_is_compliant(segment)

        expected = (('name', 'a segment'),
                    ('file_origin', None),
                    ('index', 3))
        for attribute, value in expected:
            self.assertEqual(getattr(segment, attribute), value)

    def test__construct_subsegment_by_unit(self):
        """Smoke-test ``Segment.construct_subsegment_by_unit``.

        Builds a block with two recording channel groups, seven units and
        three segments carrying analog signal arrays, then checks that the
        sub-segment built from the first four units is a compliant object.
        """
        n_segments = 3
        n_units = 7
        signal_channels = np.array([0, 2, 5])
        signal_types = ['Vm', 'Conductances']
        n_samples = 100

        # recordingchannelgroups
        rcgs = [RecordingChannelGroup(name=group_name,
                                      channel_indexes=signal_channels)
                for group_name in ('Vm', 'Conductance')]

        # Unit
        all_unit = []
        for number in range(n_units):
            unit = Unit(name='Unit #%d' % number,
                        channel_indexes=np.array([number]))
            assert_neo_object_is_compliant(unit)
            all_unit.append(unit)

        blk = Block()
        blk.recordingchannelgroups = rcgs
        segment = None
        for seg_number in range(n_segments):
            segment = Segment(name='Simulation %s' % seg_number)

            # NOTE(review): each train gets a unit but is never appended to
            # ``segment.spiketrains`` (or to the unit), so it is discarded.
            # Looks unintended -- confirm before relying on this test to
            # cover spike-train selection.
            for unit in all_unit:
                train = SpikeTrain([1, 2, 3], units='ms',
                                   t_start=0., t_stop=10)
                train.unit = unit

            # One identical all-zero signal array per signal type.
            for _ in signal_types:
                shape = (n_samples, len(signal_channels))
                segment.analogsignalarrays.append(
                    AnalogSignalArray(np.zeros(shape),
                                      units='nA',
                                      sampling_rate=1000.*pq.Hz,
                                      channel_indexes=signal_channels))
            # NOTE(review): ``segment`` is never added to ``blk.segments``;
            # only the last one created is used below.

        create_many_to_one_relationship(blk)
        for checked in all_unit + rcgs + [blk]:
            assert_neo_object_is_compliant(checked)

        # what you want
        newseg = segment.construct_subsegment_by_unit(all_unit[:4])
        assert_neo_object_is_compliant(newseg)

    def test_segment_creation(self):
        """Check the fixture segments expose the metadata, annotations and
        child containers they were built with.
        """
        seg1, seg2 = self.segment1, self.segment2

        for obj in (seg1, seg2, self.unit1, self.unit2):
            assert_neo_object_is_compliant(obj)

        # Plain attributes: (attribute, expected on seg1, expected on seg2)
        for attribute, expected1, expected2 in (
                ('name', 'test', 'test'),
                ('description', 'tester 1', 'tester 2'),
                ('file_origin', 'test.file', 'test.file')):
            self.assertEqual(getattr(seg1, attribute), expected1)
            self.assertEqual(getattr(seg2, attribute), expected2)

        self.assertEqual(seg1.annotations['testarg0'], [1, 2, 3])
        self.assertEqual(seg2.annotations['testarg10'], [1, 2, 3])

        self.assertEqual(seg1.annotations['testarg1'], 1.1)
        self.assertEqual(seg2.annotations['testarg1'], 1)
        self.assertEqual(seg2.annotations['testarg11'], 1.1)

        for segment in (seg1, seg2):
            self.assertEqual(segment.annotations['testarg2'], 'yes')
        for segment in (seg1, seg2):
            self.assertTrue(segment.annotations['testarg3'])

        # Per-container comparison of one child against its target; the
        # name comparison common to all containers is done in the loop.
        def arrays_match(res, targ):
            assert_arrays_equal(res, targ)

        def values_match(res, targ):
            self.assertEqual(res, targ)

        def epochs_match(res, targ):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.duration, targ.duration)

        def epocharrays_match(res, targ):
            assert_arrays_equal(res.times, targ.times)
            assert_arrays_equal(res.durations, targ.durations)

        def events_match(res, targ):
            self.assertEqual(res.time, targ.time)

        def eventarrays_match(res, targ):
            assert_arrays_equal(res.times, targ.times)

        # (container, comparison, seg1 targets, seg2 targets,
        #  expected seg1 size, expected seg2 size)
        checks = (
            ('analogsignals', arrays_match,
             self.sig1, self.sig2, 2, 2),
            ('analogsignalarrays', arrays_match,
             self.sigarr1, self.sigarr2, 2, 3),
            ('epochs', epochs_match,
             self.epoch1, self.epoch2, 2, 2),
            ('epocharrays', epocharrays_match,
             self.epocharr1, self.epocharr2, 2, 2),
            ('events', events_match,
             self.event1, self.event2, 2, 2),
            ('eventarrays', eventarrays_match,
             self.eventarr1, self.eventarr2, 2, 2),
            ('irregularlysampledsignals', arrays_match,
             self.irsig1, self.irsig2, 2, 2),
            ('spikes', values_match,
             self.spike1, self.spike2, 2, 2),
            ('spiketrains', arrays_match,
             self.train1, self.train2, 2, 2),
        )

        for container, compare, targs1, targs2, size1, size2 in checks:
            self.assertTrue(hasattr(seg1, container))
            self.assertTrue(hasattr(seg2, container))

            children1 = getattr(seg1, container)
            children2 = getattr(seg2, container)
            self.assertEqual(len(children1), size1)
            self.assertEqual(len(children2), size2)

            for children, targs in ((children1, targs1),
                                    (children2, targs2)):
                for res, targ in zip(children, targs):
                    compare(res, targ)
                    self.assertEqual(res.name, targ.name)

    def test_segment_merge(self):
        """Merge ``segment2`` into ``segment1`` and check that ``segment1``
        gains the children while ``segment2`` and all metadata are unchanged.
        """
        seg1, seg2 = self.segment1, self.segment2

        seg1.merge(seg2)
        create_many_to_one_relationship(seg1, force=True)
        assert_neo_object_is_compliant(seg1)

        # Plain attributes: (attribute, expected on seg1, expected on seg2)
        for attribute, expected1, expected2 in (
                ('name', 'test', 'test'),
                ('description', 'tester 1', 'tester 2'),
                ('file_origin', 'test.file', 'test.file')):
            self.assertEqual(getattr(seg1, attribute), expected1)
            self.assertEqual(getattr(seg2, attribute), expected2)

        self.assertEqual(seg1.annotations['testarg0'], [1, 2, 3])
        self.assertEqual(seg2.annotations['testarg10'], [1, 2, 3])

        self.assertEqual(seg1.annotations['testarg1'], 1.1)
        self.assertEqual(seg2.annotations['testarg1'], 1)
        self.assertEqual(seg2.annotations['testarg11'], 1.1)

        for segment in (seg1, seg2):
            self.assertEqual(segment.annotations['testarg2'], 'yes')
        for segment in (seg1, seg2):
            self.assertTrue(segment.annotations['testarg3'])

        # Per-container comparison of one child against its target; the
        # name comparison common to all containers is done in the loop.
        def arrays_match(res, targ):
            assert_arrays_equal(res, targ)

        def values_match(res, targ):
            self.assertEqual(res, targ)

        def epochs_match(res, targ):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.duration, targ.duration)

        def epocharrays_match(res, targ):
            assert_arrays_equal(res.times, targ.times)
            assert_arrays_equal(res.durations, targ.durations)

        def events_match(res, targ):
            self.assertEqual(res.time, targ.time)

        def eventarrays_match(res, targ):
            assert_arrays_equal(res.times, targ.times)

        # (container, comparison, merged seg1 targets, seg2 targets,
        #  expected seg1 size, expected seg2 size)
        checks = (
            ('analogsignals', arrays_match,
             self.sig, self.sig2, 4, 2),
            ('analogsignalarrays', arrays_match,
             self.sigarr, self.sigarr2, 4, 3),
            ('epochs', epochs_match,
             self.epoch, self.epoch2, 4, 2),
            ('epocharrays', epocharrays_match,
             self.epocharr, self.epocharr2, 4, 2),
            ('events', events_match,
             self.event, self.event2, 4, 2),
            ('eventarrays', eventarrays_match,
             self.eventarr, self.eventarr2, 4, 2),
            ('irregularlysampledsignals', arrays_match,
             self.irsig, self.irsig2, 4, 2),
            ('spikes', values_match,
             self.spike, self.spike2, 4, 2),
            ('spiketrains', arrays_match,
             self.train, self.train2, 4, 2),
        )

        for container, compare, targs1, targs2, size1, size2 in checks:
            self.assertTrue(hasattr(seg1, container))
            self.assertTrue(hasattr(seg2, container))

            children1 = getattr(seg1, container)
            children2 = getattr(seg2, container)
            self.assertEqual(len(children1), size1)
            self.assertEqual(len(children2), size2)

            for children, targs in ((children1, targs1),
                                    (children2, targs2)):
                for res, targ in zip(children, targs):
                    compare(res, targ)
                    self.assertEqual(res.name, targ.name)

    def test_segment_all_data(self):
        """``segment1.all_data`` yields the segment's children in the same
        order as the concatenated fixture lists.
        """
        targs = []
        for group in (self.epoch1, self.epocharr1,
                      self.event1, self.eventarr1,
                      self.sig1, self.sigarr1, self.irsig1,
                      self.spike1, self.train1):
            targs = targs + group

        # Objects without a non-zero ``ndim`` are compared as plain values,
        # everything else element-wise.
        for res, targ in zip(self.segment1.all_data, targs):
            if not hasattr(res, 'ndim') or not res.ndim:
                self.assertEqual(res, targ)
            else:
                assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

    def test_segment_take_spikes_by_unit(self):
        """``take_spikes_by_unit`` returns nothing without units and the
        matching fixture spikes for each single unit.
        """
        take = self.segment1.take_spikes_by_unit
        without_units = take()
        per_unit = [take([unit]) for unit in (self.unit1, self.unit2)]

        self.assertEqual(without_units, [])

        expected = (self.unit1spike, self.unit2spike)
        for found, targets in zip(per_unit, expected):
            for res, targ in zip(found, targets):
                self.assertEqual(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_take_spiketrains_by_unit(self):
        """``take_spiketrains_by_unit`` returns nothing without units and
        the matching fixture trains for each single unit.
        """
        take = self.segment1.take_spiketrains_by_unit
        without_units = take()
        per_unit = [take([unit]) for unit in (self.unit1, self.unit2)]

        self.assertEqual(without_units, [])

        expected = (self.unit1train, self.unit2train)
        for found, targets in zip(per_unit, expected):
            for res, targ in zip(found, targets):
                assert_arrays_equal(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_take_analogsignal_by_unit(self):
        """``take_analogsignal_by_unit`` returns nothing without units and
        the per-channel fixture signals for each single unit.
        """
        take = self.segment1.take_analogsignal_by_unit
        without_units = take()
        per_unit = [take([unit]) for unit in (self.unit1, self.unit2)]

        self.assertEqual(without_units, [])

        expected = (self.chan1sig, self.chan2sig)
        for found, targets in zip(per_unit, expected):
            for res, targ in zip(found, targets):
                assert_arrays_equal(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_take_analogsignal_by_channelindex(self):
        """``take_analogsignal_by_channelindex`` returns nothing without
        indexes and the per-channel fixture signals for indexes 1 and 2.
        """
        take = self.segment1.take_analogsignal_by_channelindex
        without_indexes = take()
        per_channel = [take([channel]) for channel in (1, 2)]

        self.assertEqual(without_indexes, [])

        expected = (self.chan1sig, self.chan2sig)
        for found, targets in zip(per_channel, expected):
            for res, targ in zip(found, targets):
                assert_arrays_equal(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_take_slice_of_analogsignalarray_by_unit(self):
        """``take_slice_of_analogsignalarray_by_unit`` returns nothing
        without units and the per-channel array slices for each unit.
        """
        take = self.segment1.take_slice_of_analogsignalarray_by_unit
        without_units = take()
        per_unit = [take([unit]) for unit in (self.unit1, self.unit2)]

        self.assertEqual(without_units, [])

        expected = (self.chan1sigarr1, self.chan2sigarr1)
        for found, targets in zip(per_unit, expected):
            for res, targ in zip(found, targets):
                assert_arrays_equal(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_take_slice_of_analogsignalarray_by_channelindex(self):
        """``take_slice_of_analogsignalarray_by_channelindex`` returns
        nothing without indexes and the array slices for indexes 1 and 2.
        """
        take = self.segment1.take_slice_of_analogsignalarray_by_channelindex
        without_indexes = take()
        per_channel = [take([channel]) for channel in (1, 2)]

        self.assertEqual(without_indexes, [])

        expected = (self.chan1sigarr1, self.chan2sigarr1)
        for found, targets in zip(per_channel, expected):
            for res, targ in zip(found, targets):
                assert_arrays_equal(res, targ)
                self.assertEqual(res.name, targ.name)

    def test_segment_size(self):
        """``segment1.size()`` reports two children in every container."""
        containers = ("epochs", "events", "analogsignals",
                      "irregularlysampledsignals", "spikes",
                      "spiketrains", "epocharrays", "eventarrays",
                      "analogsignalarrays")
        expected = dict((container, 2) for container in containers)

        self.assertEqual(self.segment1.size(), expected)

    def test_segment_filter(self):
        """``filter`` with no criteria finds nothing; filtering by name or
        by the ``testattr`` keyword finds the matching fixture children.
        """
        segment = self.segment1
        unfiltered = segment.filter()
        by_name = segment.filter(name='analogsignal 1 1')
        by_attr = segment.filter(testattr=True)

        self.assertEqual(unfiltered, [])

        self.assertEqual(len(by_name), 1)
        assert_arrays_equal(by_name[0], self.sig1[0])

        # Expected matches, in order: first epoch, then first event.
        self.assertEqual(len(by_attr), 2)
        for found, expected in zip(by_attr,
                                   (self.epoch1[0], self.event1[0])):
            self.assertEqual(found, expected)
Example #32
0
    def read_block(self, block_index=0, lazy=False, cascade=True, signal_group_mode=None,
                units_group_mode=None, load_waveforms=False, time_slices=None):
        """
        Read one block and return it as a neo ``Block``.

        :param block_index: int default 0. In case of several blocks, block_index can be specified.

        :param lazy: False by default. Forwarded to ``read_segment``.

        :param cascade: True by default. When False, the block is returned with
            its annotations only (no ChannelIndex, Unit or Segment is created).

        :param signal_group_mode: 'split-all' or 'group-by-same-units' (default depends on the IO):
        This controls the behavior for grouping channels in AnalogSignal.
            * 'split-all': each channel will give an AnalogSignal
            * 'group-by-same-units': all channels sharing the same quantity units are grouped in
            a 2D AnalogSignal

        :param units_group_mode: 'split-all' or 'all-in-one' (default depends on the IO)
        This controls the behavior for grouping Unit in ChannelIndex:
            * 'split-all': each neo.Unit is assigned to a new neo.ChannelIndex
            * 'all-in-one': all neo.Unit are grouped in the same neo.ChannelIndex (global spike sorting for instance)

        :param load_waveforms: False by default. Controls whether SpikeTrain.waveforms is None or not.

        :param time_slices: None by default. List of time_slice. A time slice is (t_start, t_stop), both are quantities.
            Each element will lead to a fake neo.Segment, so len(block.segments) == len(time_slices).
            Every time_slice must lie inside the time range of one original segment.

        :return: the populated ``Block``.

        :raises ValueError: if a time slice does not fit inside any segment.
        """

        if signal_group_mode is None:
            signal_group_mode = self._prefered_signal_group_mode

        if units_group_mode is None:
            units_group_mode = self._prefered_units_group_mode

        # annotations
        bl_annotations = dict(self.raw_annotations['blocks'][block_index])
        bl_annotations.pop('segments')
        bl_annotations = check_annotations(bl_annotations)

        bl = Block(**bl_annotations)

        if not cascade:
            return bl

        # ChannelIndex are split in 2 parts:
        #  * some for AnalogSignals
        #  * some for Units

        # ChannelIndex for AnalogSignals
        all_channels = self.header['signal_channels']
        channel_indexes_list = self.get_group_channel_indexes()
        for channel_index in channel_indexes_list:
            subgroups = self._make_signal_channel_subgroups(
                channel_index, signal_group_mode=signal_group_mode)
            for i, (ind_within, ind_abs) in subgroups.items():
                neo_channel_index = ChannelIndex(
                    index=ind_within,
                    channel_names=all_channels[ind_abs]['name'].astype('S'),
                    channel_ids=all_channels[ind_abs]['id'],
                    name='Channel group {}'.format(i))
                bl.channel_indexes.append(neo_channel_index)

        # ChannelIndex and Unit
        # 2 cases are possible in neo; different IOs have chosen one or the other:
        #  * All units are grouped in the same ChannelIndex and indexes are all channels: 'all-in-one'
        #  * Each unit is assigned to one ChannelIndex: 'split-all'
        # This is kept for compatibility
        unit_channels = self.header['unit_channels']
        if units_group_mode == 'all-in-one':
            if unit_channels.size > 0:
                unit_chx = ChannelIndex(index=np.array([], dtype='i'),
                                        name='ChannelIndex for all Unit')
                bl.channel_indexes.append(unit_chx)
                for c in range(unit_channels.size):
                    unit_annotations = self.raw_annotations['unit_channels'][c]
                    unit_chx.units.append(Unit(**unit_annotations))

        elif units_group_mode == 'split-all':
            for c in range(len(unit_channels)):
                unit_annotations = self.raw_annotations['unit_channels'][c]
                unit = Unit(**unit_annotations)
                unit_chx = ChannelIndex(index=np.array([], dtype='i'),
                                        name='ChannelIndex for Unit')
                unit_chx.units.append(unit)
                bl.channel_indexes.append(unit_chx)

        if time_slices is None:
            # Read the real segment counts
            for seg_index in range(self.segment_count(block_index)):
                seg = self.read_segment(block_index=block_index, seg_index=seg_index,
                                        lazy=lazy, cascade=cascade,
                                        signal_group_mode=signal_group_mode,
                                        load_waveforms=load_waveforms)
                bl.segments.append(seg)

        else:
            # return a fake segment list corresponding to time_slices
            for s, time_slice in enumerate(time_slices):
                # find in which segment time_slice is
                t_start, t_stop = time_slice
                t_start = ensure_second(t_start)
                t_stop = ensure_second(t_stop)
                related_seg_index = None
                # No early exit: if several segments contain the slice,
                # the last one is used.
                for seg_index in range(self.segment_count(block_index)):
                    seg_t_start = self.segment_t_start(block_index, seg_index) * pq.s
                    seg_t_stop = self.segment_t_stop(block_index, seg_index) * pq.s
                    if (seg_t_start <= t_start <= seg_t_stop) and (seg_t_start <= t_stop <= seg_t_stop):
                        related_seg_index = seg_index

                if related_seg_index is None:
                    raise ValueError('time_slice not in any segment range  {}'.format(time_slice))

                seg = self.read_segment(block_index=block_index, seg_index=related_seg_index,
                                        lazy=lazy, cascade=cascade,
                                        signal_group_mode=signal_group_mode,
                                        load_waveforms=load_waveforms, time_slice=time_slice)
                seg.index = s
                bl.segments.append(seg)
                # The signals of this segment are linked to their
                # ChannelIndex once, by the loop below; linking them here as
                # well would register every signal twice.

        # create link to other containers ChannelIndex and Units
        for seg in bl.segments:
            for c, anasig in enumerate(seg.analogsignals):
                bl.channel_indexes[c].analogsignals.append(anasig)

            # The ChannelIndex objects holding units were appended after the
            # signal ones, so they are addressed with an offset of ``nsig``.
            nsig = len(seg.analogsignals)
            for c, sptr in enumerate(seg.spiketrains):
                if units_group_mode == 'all-in-one':
                    bl.channel_indexes[nsig].units[c].spiketrains.append(sptr)
                elif units_group_mode == 'split-all':
                    bl.channel_indexes[nsig + c].units[0].spiketrains.append(sptr)

        bl.create_many_to_one_relationship()

        return bl
Example #33
0
    def test__issue_285(self):
        """Check child types and parent links after a PickleIO round trip.

        Four scenarios are exercised in turn (SpikeTrain with a Unit, Epoch,
        Event, IrregularlySampledSignal).  Each one builds a small Block,
        writes it with ``PickleIO``, reads it back and asserts on the type of
        the restored child or of its parent reference.

        The scratch file is removed inside a ``finally`` clause so that a
        failing write/read does not leave ``blk.pkl`` behind in the working
        directory.
        """
        filename = "blk.pkl"

        def roundtrip(block):
            """Write *block* to the scratch file, read it back, delete the file."""
            try:
                PickleIO(filename=filename).write(block)
                return PickleIO(filename=filename).read_block()
            finally:
                # The file may not exist if the write itself failed; guard so
                # the cleanup never masks the original exception.
                if os.path.exists(filename):
                    os.remove(filename)

        # Spiketrain
        train = SpikeTrain([3, 4, 5] * pq.s, t_stop=10.0)
        unit = Unit()
        train.unit = unit
        unit.spiketrains.append(train)

        epoch = Epoch(np.array([0, 10, 20]),
                      np.array([2, 2, 2]),
                      np.array(["a", "b", "c"]),
                      units="ms")

        blk = Block()
        seg = Segment()
        seg.spiketrains.append(train)
        seg.epochs.append(epoch)
        epoch.segment = seg
        blk.segments.append(seg)

        r_seg = roundtrip(blk).segments[0]
        self.assertIsInstance(r_seg.spiketrains[0].unit, Unit)
        self.assertIsInstance(r_seg.epochs[0], Epoch)

        # Epoch
        epoch = Epoch(times=np.arange(0, 30, 10) * pq.s,
                      durations=[10, 5, 7] * pq.ms,
                      labels=np.array(['btn0', 'btn1', 'btn2'], dtype='U'))
        epoch.segment = Segment()
        blk = Block()
        seg = Segment()
        seg.epochs.append(epoch)
        blk.segments.append(seg)

        r_seg = roundtrip(blk).segments[0]
        self.assertIsInstance(r_seg.epochs[0].segment, Segment)

        # Event
        event = Event(np.arange(0, 30, 10) * pq.s,
                      labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U'))
        event.segment = Segment()

        blk = Block()
        seg = Segment()
        seg.events.append(event)
        blk.segments.append(seg)

        r_seg = roundtrip(blk).segments[0]
        self.assertIsInstance(r_seg.events[0].segment, Segment)

        # IrregularlySampledSignal
        signal = IrregularlySampledSignal([0.0, 1.23, 6.78], [1, 2, 3],
                                          units='mV',
                                          time_units='ms')
        signal.segment = Segment()

        blk = Block()
        seg = Segment()
        seg.irregularlysampledsignals.append(signal)
        blk.segments.append(seg)
        blk.segments[0].block = blk

        r_seg = roundtrip(blk).segments[0]
        self.assertIsInstance(r_seg.irregularlysampledsignals[0].segment,
                              Segment)
Example #34
0
class TestSegment(unittest.TestCase):
    def setUp(self):
        """Build every fixture used by the tests of this class.

        The child-object fixtures are created first.  Units and segments come
        last because ``setup_units`` and ``setup_segments`` read the lists
        produced by the earlier steps.
        """
        build_steps = (
            self.setup_analogsignals,
            self.setup_analogsignalarrays,
            self.setup_epochs,
            self.setup_epocharrays,
            self.setup_events,
            self.setup_eventarrays,
            self.setup_irregularlysampledsignals,
            self.setup_spikes,
            self.setup_spiketrains,
            # Containers: these consume the fixtures created above.
            self.setup_units,
            self.setup_segments,
        )
        for build in build_steps:
            build()

    def setup_segments(self):
        """Create ``segment1`` and ``segment2`` and attach the child fixtures.

        Both segments share name, file origin and the ``testarg1..3``
        annotations; they differ in description and in the extra annotations
        added through ``annotate``.  Every child list prepared by the other
        ``setup_*`` methods is assigned to the matching container attribute,
        then the parent links are created.
        """
        shared = {'testarg2': 'yes', 'testarg3': True}
        first = Segment(name='test', description='tester 1',
                        file_origin='test.file',
                        testarg1=1, **shared)
        second = Segment(name='test', description='tester 2',
                         file_origin='test.file',
                         testarg1=1, **shared)
        first.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        second.annotate(testarg11=1.1, testarg10=[1, 2, 3])
        self.segment1 = first
        self.segment2 = second

        # (container attribute, children of segment1, children of segment2)
        children = (
            ('analogsignals', self.sig1, self.sig2),
            ('analogsignalarrays', self.sigarr1, self.sigarr2),
            ('epochs', self.epoch1, self.epoch2),
            ('epocharrays', self.epocharr1, self.epocharr2),
            ('events', self.event1, self.event2),
            ('eventarrays', self.eventarr1, self.eventarr2),
            ('irregularlysampledsignals', self.irsig1, self.irsig2),
            ('spikes', self.spike1, self.spike2),
            ('spiketrains', self.train1, self.train2),
        )
        for container, for_first, for_second in children:
            setattr(first, container, for_first)
            setattr(second, container, for_second)

        first.create_many_to_one_relationship()
        second.create_many_to_one_relationship()

    def setup_units(self):
        """Create ``unit1`` and ``unit2`` and attach spike trains and spikes.

        Each unit receives one child from the first fixture group and one
        from the second, crossed over: ``unit1`` gets element 0 of group 1
        and element 1 of group 2, ``unit2`` gets the remaining two.
        """
        shared = {'testarg2': 'yes', 'testarg3': True}
        first = Unit(name='test', description='tester 1',
                     file_origin='test.file',
                     channel_indexes=np.array([1]),
                     testarg1=1, **shared)
        second = Unit(name='test', description='tester 2',
                      file_origin='test.file',
                      channel_indexes=np.array([2]),
                      testarg1=1, **shared)
        first.annotate(testarg1=1.1, testarg0=[1, 2, 3])
        second.annotate(testarg11=1.1, testarg10=[1, 2, 3])
        self.unit1 = first
        self.unit2 = second

        # Spike trains, crossed over between the two fixture groups.
        self.unit1train = [self.train1[0], self.train2[1]]
        self.unit2train = [self.train1[1], self.train2[0]]
        first.spiketrains = self.unit1train
        second.spiketrains = self.unit2train

        # Spikes follow the same crossing pattern.
        self.unit1spike = [self.spike1[0], self.spike2[1]]
        self.unit2spike = [self.spike1[1], self.spike2[0]]
        first.spikes = self.unit1spike
        second.spikes = self.unit2spike

        first.create_many_to_one_relationship()
        second.create_many_to_one_relationship()

    def setup_analogsignals(self):
        """Create four AnalogSignal fixtures in two groups of two.

        Group 1 holds millivolt data, group 2 holds volt data.  Within each
        group the signals use channel index 1 and 2.  ``chan1sig`` and
        ``chan2sig`` regroup the same objects by position within the group.
        """
        names1 = ['analogsignal 1 1', 'analogsignal 1 2']
        names2 = ['analogsignal 2 1', 'analogsignal 2 2']

        data1 = [np.arange(0, 10) * pq.mV, np.arange(10, 20) * pq.mV]
        data2 = [np.arange(20, 30) * pq.V, np.arange(30, 40) * pq.V]

        self.signames1 = names1
        self.signames2 = names2
        self.signames = names1 + names2

        def build(values, label, channel):
            return AnalogSignal(values, name=label,
                                channel_index=channel, sampling_rate=1*pq.Hz)

        channels = (1, 2)
        self.sig1 = [build(values, label, channel)
                     for values, label, channel
                     in zip(data1, names1, channels)]
        self.sig2 = [build(values, label, channel)
                     for values, label, channel
                     in zip(data2, names2, channels)]
        self.sig = self.sig1 + self.sig2

        self.chan1sig = [self.sig1[0], self.sig2[0]]
        self.chan2sig = [self.sig1[1], self.sig2[1]]

    def setup_analogsignalarrays(self):
        """Create the AnalogSignalArray fixtures and their per-column slices.

        Six arrays are built: two for group 1, three for group 2 (the third
        reuses the data and name of the first group-1 array) and one
        ``sigarr112`` whose data is the first array stacked horizontally with
        itself.  ``self.sigarr`` lists ``sigarr112`` in place of the first
        array.
        """
        name11, name12 = 'analogsignalarray 1 1', 'analogsignalarray 1 2'
        name21, name22 = 'analogsignalarray 2 1', 'analogsignalarray 2 2'

        data11 = np.arange(0, 10).reshape(5, 2) * pq.mV
        data12 = np.arange(10, 20).reshape(5, 2) * pq.mV
        data21 = np.arange(20, 30).reshape(5, 2) * pq.V
        data22 = np.arange(30, 40).reshape(5, 2) * pq.V
        data112 = np.hstack([data11, data11]) * pq.mV

        self.sigarrnames1 = [name11, name12]
        self.sigarrnames2 = [name21, name22, name11]
        self.sigarrnames = [name11, name12, name21, name22]

        def build(values, label, channels):
            return AnalogSignalArray(values, name=label,
                                     sampling_rate=1*pq.Hz,
                                     channel_index=np.array(channels))

        arr11 = build(data11, name11, [1, 2])
        arr12 = build(data12, name12, [2, 1])
        arr21 = build(data21, name21, [1, 2])
        arr22 = build(data22, name22, [2, 1])
        arr23 = build(data11, name11, [1, 2])
        arr112 = build(data112, name11, [1, 2])

        self.sigarr1 = [arr11, arr12]
        self.sigarr2 = [arr21, arr22, arr23]
        self.sigarr = [arr112, arr12, arr21, arr22]

        def column(arr, position):
            # Slice (not index) so the result keeps its second dimension.
            return arr[:, position:position + 1]

        self.chan1sigarr1 = [column(arr11, 0), column(arr12, 1)]
        self.chan2sigarr1 = [column(arr11, 1), column(arr12, 0)]
        self.chan1sigarr2 = [column(arr21, 0), column(arr22, 1),
                             column(arr23, 0)]
        # NOTE(review): the last entry takes column 0 of arr23, the same
        # column used in chan1sigarr2; the other [1, 2]-indexed arrays use
        # column 1 for channel 2.  Possibly a copy/paste slip; kept as-is,
        # confirm against the tests that consume chan2sigarr2.
        self.chan2sigarr2 = [column(arr21, 1), column(arr22, 0),
                             column(arr23, 0)]

    def setup_epochs(self):
        """Create four Epoch fixtures in two groups of two.

        Group 1 uses millisecond times with second durations and carries a
        ``testattr`` keyword; group 2 uses second times with millisecond
        durations and no ``testattr``.
        """
        names = ['epoch 1 1', 'epoch 1 2', 'epoch 2 1', 'epoch 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]
        durations = [11 * pq.s, 21 * pq.s, 31 * pq.ms, 41 * pq.ms]
        channels = [1, 2, 1, 2]
        # Only the first two epochs receive the extra keyword.
        extras = [{'testattr': True}, {'testattr': False}, {}, {}]

        self.epochnames1 = names[:2]
        self.epochnames2 = names[2:]
        self.epochnames = names

        epochs = [Epoch(time, duration,
                        label=label, name=label, channel_index=channel,
                        **extra)
                  for time, duration, label, channel, extra
                  in zip(times, durations, names, channels, extras)]

        self.epoch1 = epochs[:2]
        self.epoch2 = epochs[2:]
        self.epoch = epochs

    def setup_epocharrays(self):
        """Create four EpochArray fixtures in two groups of two.

        Group 1 uses millisecond times with second durations; group 2 uses
        second times with millisecond durations.  Each array has ten entries.
        """
        names = ['epocharr 1 1', 'epocharr 1 2',
                 'epocharr 2 1', 'epocharr 2 2']
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]
        durations = [np.arange(1, 11) * pq.s,
                     np.arange(11, 21) * pq.s,
                     np.arange(21, 31) * pq.ms,
                     np.arange(31, 41) * pq.ms]

        self.epocharrnames1 = names[:2]
        self.epocharrnames2 = names[2:]
        self.epocharrnames = names

        arrays = [EpochArray(start, length, label=label, name=label)
                  for start, length, label in zip(times, durations, names)]

        self.epocharr1 = arrays[:2]
        self.epocharr2 = arrays[2:]
        self.epocharr = arrays

    def setup_events(self):
        """Create four Event fixtures in two groups of two.

        Group 1 uses millisecond times and carries a ``testattr`` keyword
        (``True`` and ``5``); group 2 uses second times and no ``testattr``.
        """
        names = ['event 1 1', 'event 1 2', 'event 2 1', 'event 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]
        # Only the first two events receive the extra keyword.
        extras = [{'testattr': True}, {'testattr': 5}, {}, {}]

        self.eventnames1 = names[:2]
        self.eventnames2 = names[2:]
        self.eventnames = names

        events = [Event(time, label=label, name=label, **extra)
                  for time, label, extra in zip(times, names, extras)]

        self.event1 = events[:2]
        self.event2 = events[2:]
        self.event = events

    def setup_eventarrays(self):
        """Create four EventArray fixtures in two groups of two.

        Group 1 uses millisecond times, group 2 uses second times; each array
        has ten entries.
        """
        names = ['eventarr 1 1', 'eventarr 1 2',
                 'eventarr 2 1', 'eventarr 2 2']
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.eventarrnames1 = names[:2]
        self.eventarrnames2 = names[2:]
        self.eventarrnames = names

        arrays = [EventArray(when, label=label, name=label)
                  for when, label in zip(times, names)]

        self.eventarr1 = arrays[:2]
        self.eventarr2 = arrays[2:]
        self.eventarr = arrays

    def setup_irregularlysampledsignals(self):
        """Create four IrregularlySampledSignal fixtures in two groups of two.

        Group 1 pairs millisecond times with milliampere data; group 2 pairs
        second times with ampere data.  Each signal has ten samples.
        """
        names = ['irregularsignal 1 1', 'irregularsignal 1 2',
                 'irregularsignal 2 1', 'irregularsignal 2 2']
        values = [np.arange(0, 10) * pq.mA,
                  np.arange(10, 20) * pq.mA,
                  np.arange(20, 30) * pq.A,
                  np.arange(30, 40) * pq.A]
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.irsignames1 = names[:2]
        self.irsignames2 = names[2:]
        self.irsignames = names

        signals = [IrregularlySampledSignal(when, data, name=label)
                   for when, data, label in zip(times, values, names)]

        self.irsig1 = signals[:2]
        self.irsig2 = signals[2:]
        self.irsig = signals

    def setup_spikes(self):
        """Create four Spike fixtures in two groups of two.

        Group 1 uses millisecond times, group 2 uses second times; every
        spike is created with ``t_stop`` of 100 s.
        """
        names = ['spike 1 1', 'spike 1 2', 'spike 2 1', 'spike 2 2']
        times = [10 * pq.ms, 20 * pq.ms, 30 * pq.s, 40 * pq.s]

        self.spikenames1 = names[:2]
        self.spikenames2 = names[2:]
        self.spikenames = names

        spikes = [Spike(when, t_stop=100*pq.s, name=label)
                  for when, label in zip(times, names)]

        self.spike1 = spikes[:2]
        self.spike2 = spikes[2:]
        self.spike = spikes

    def setup_spiketrains(self):
        """Create four SpikeTrain fixtures in two groups of two.

        Group 1 uses millisecond times, group 2 uses second times; every
        train holds ten spikes and is created with ``t_stop`` of 100 s.
        """
        names = ['spiketrain 1 1', 'spiketrain 1 2',
                 'spiketrain 2 1', 'spiketrain 2 2']
        times = [np.arange(0, 10) * pq.ms,
                 np.arange(10, 20) * pq.ms,
                 np.arange(20, 30) * pq.s,
                 np.arange(30, 40) * pq.s]

        self.trainnames1 = names[:2]
        self.trainnames2 = names[2:]
        self.trainnames = names

        trains = [SpikeTrain(when, t_stop=100*pq.s, name=label)
                  for when, label in zip(times, names)]

        self.train1 = trains[:2]
        self.train2 = trains[2:]
        self.train = trains

    def test_init(self):
        """A Segment built with name and index exposes them; file_origin is None."""
        segment = Segment(name='a segment', index=3)
        assert_neo_object_is_compliant(segment)

        expected = (
            ('name', 'a segment'),
            ('file_origin', None),
            ('index', 3),
        )
        for attribute, value in expected:
            self.assertEqual(getattr(segment, attribute), value)

    def test__construct_subsegment_by_unit(self):
        """Smoke-test ``Segment.construct_subsegment_by_unit``.

        Builds a block with two channel groups, seven units and three
        segments carrying two signal arrays each.  It then asks the last
        segment for a sub-segment restricted to the first four units and
        checks that the result is a compliant neo object.
        """
        segment_count = 3
        unit_count = 7
        channels_with_signal = np.array([0, 2, 5])
        signal_kinds = ['Vm', 'Conductances']
        n_samples = 100

        # recordingchannelgroups
        groups = [RecordingChannelGroup(name=label,
                                        channel_indexes=channels_with_signal)
                  for label in ('Vm', 'Conductance')]

        # Unit
        units = []
        for number in range(unit_count):
            unit = Unit(name='Unit #%d' % number,
                        channel_indexes=np.array([number]))
            assert_neo_object_is_compliant(unit)
            units.append(unit)

        block = Block()
        block.recordingchannelgroups = groups
        for seg_number in range(segment_count):
            seg = Segment(name='Simulation %s' % seg_number)
            # NOTE(review): each train is only linked to its unit and is not
            # appended to the segment or to the unit's spiketrains here;
            # the segment is not appended to the block either.  Confirm
            # whether that is intentional for this test.
            for owner in units:
                train = SpikeTrain([1, 2, 3], units='ms',
                                   t_start=0., t_stop=10)
                train.unit = owner

            # One signal array per signal kind; the kind itself is not used.
            for _ in signal_kinds:
                samples = np.zeros((n_samples, len(channels_with_signal)))
                seg.analogsignalarrays.append(
                    AnalogSignalArray(samples,
                                      units='nA',
                                      sampling_rate=1000.*pq.Hz,
                                      channel_indexes=channels_with_signal))

        block.create_many_to_one_relationship()
        for unit in units:
            assert_neo_object_is_compliant(unit)
        for group in groups:
            assert_neo_object_is_compliant(group)
        assert_neo_object_is_compliant(block)

        # what you want: ``seg`` is the segment created by the last iteration
        subsegment = seg.construct_subsegment_by_unit(units[:4])
        assert_neo_object_is_compliant(subsegment)

    def test_segment_creation(self):
        """Verify the metadata and every child container of both segments.

        The scalar attributes and annotations are checked first.  Then, for
        each child container in turn, the test asserts that the attribute
        exists on both segments, that it has the expected length, and that
        each child matches the fixture it was built from.
        """
        seg1, seg2 = self.segment1, self.segment2

        for obj in (seg1, seg2, self.unit1, self.unit2):
            assert_neo_object_is_compliant(obj)

        # Scalar attributes: (attribute, expected on seg1, expected on seg2)
        scalars = (
            ('name', 'test', 'test'),
            ('description', 'tester 1', 'tester 2'),
            ('file_origin', 'test.file', 'test.file'),
        )
        for attribute, want1, want2 in scalars:
            self.assertEqual(getattr(seg1, attribute), want1)
            self.assertEqual(getattr(seg2, attribute), want2)

        # Annotations: constructor keywords plus what ``annotate`` added.
        self.assertEqual(seg1.annotations['testarg0'], [1, 2, 3])
        self.assertEqual(seg2.annotations['testarg10'], [1, 2, 3])

        self.assertEqual(seg1.annotations['testarg1'], 1.1)
        self.assertEqual(seg2.annotations['testarg1'], 1)
        self.assertEqual(seg2.annotations['testarg11'], 1.1)

        for seg in (seg1, seg2):
            self.assertEqual(seg.annotations['testarg2'], 'yes')
        for seg in (seg1, seg2):
            self.assertTrue(seg.annotations['testarg3'])

        # Per-child comparison routines, one per kind of child object.
        def same_array(res, targ):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        def same_epoch(res, targ):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.duration, targ.duration)
            self.assertEqual(res.name, targ.name)

        def same_epocharray(res, targ):
            assert_arrays_equal(res.times, targ.times)
            assert_arrays_equal(res.durations, targ.durations)
            self.assertEqual(res.name, targ.name)

        def same_event(res, targ):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.name, targ.name)

        def same_eventarray(res, targ):
            assert_arrays_equal(res.times, targ.times)
            self.assertEqual(res.name, targ.name)

        def same_spike(res, targ):
            self.assertEqual(res, targ)
            self.assertEqual(res.name, targ.name)

        # (container, len on seg1, len on seg2, targets1, targets2, checker)
        containers = (
            ('analogsignals', 2, 2, self.sig1, self.sig2, same_array),
            ('analogsignalarrays', 2, 3,
             self.sigarr1, self.sigarr2, same_array),
            ('epochs', 2, 2, self.epoch1, self.epoch2, same_epoch),
            ('epocharrays', 2, 2,
             self.epocharr1, self.epocharr2, same_epocharray),
            ('events', 2, 2, self.event1, self.event2, same_event),
            ('eventarrays', 2, 2,
             self.eventarr1, self.eventarr2, same_eventarray),
            ('irregularlysampledsignals', 2, 2,
             self.irsig1, self.irsig2, same_array),
            ('spikes', 2, 2, self.spike1, self.spike2, same_spike),
            ('spiketrains', 2, 2, self.train1, self.train2, same_array),
        )
        for name, count1, count2, targets1, targets2, check in containers:
            self.assertTrue(hasattr(seg1, name))
            self.assertTrue(hasattr(seg2, name))

            self.assertEqual(len(getattr(seg1, name)), count1)
            self.assertEqual(len(getattr(seg2, name)), count2)

            for res, targ in zip(getattr(seg1, name), targets1):
                check(res, targ)

            for res, targ in zip(getattr(seg2, name), targets2):
                check(res, targ)

    def test_segment_merge(self):
        self.segment1.merge(self.segment2)
        self.segment1.create_many_to_one_relationship(force=True)
        assert_neo_object_is_compliant(self.segment1)

        self.assertEqual(self.segment1.name, 'test')
        self.assertEqual(self.segment2.name, 'test')

        self.assertEqual(self.segment1.description, 'tester 1')
        self.assertEqual(self.segment2.description, 'tester 2')

        self.assertEqual(self.segment1.file_origin, 'test.file')
        self.assertEqual(self.segment2.file_origin, 'test.file')

        self.assertEqual(self.segment1.annotations['testarg0'], [1, 2, 3])
        self.assertEqual(self.segment2.annotations['testarg10'], [1, 2, 3])

        self.assertEqual(self.segment1.annotations['testarg1'], 1.1)
        self.assertEqual(self.segment2.annotations['testarg1'], 1)
        self.assertEqual(self.segment2.annotations['testarg11'], 1.1)

        self.assertEqual(self.segment1.annotations['testarg2'], 'yes')
        self.assertEqual(self.segment2.annotations['testarg2'], 'yes')

        self.assertTrue(self.segment1.annotations['testarg3'])
        self.assertTrue(self.segment2.annotations['testarg3'])

        self.assertTrue(hasattr(self.segment1, 'analogsignals'))
        self.assertTrue(hasattr(self.segment2, 'analogsignals'))

        self.assertEqual(len(self.segment1.analogsignals), 4)
        self.assertEqual(len(self.segment2.analogsignals), 2)

        for res, targ in zip(self.segment1.analogsignals, self.sig):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.analogsignals, self.sig2):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'analogsignalarrays'))
        self.assertTrue(hasattr(self.segment2, 'analogsignalarrays'))

        self.assertEqual(len(self.segment1.analogsignalarrays), 4)
        self.assertEqual(len(self.segment2.analogsignalarrays), 3)

        for res, targ in zip(self.segment1.analogsignalarrays, self.sigarr):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.analogsignalarrays, self.sigarr2):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'epochs'))
        self.assertTrue(hasattr(self.segment2, 'epochs'))

        self.assertEqual(len(self.segment1.epochs), 4)
        self.assertEqual(len(self.segment2.epochs), 2)

        for res, targ in zip(self.segment1.epochs, self.epoch):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.duration, targ.duration)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.epochs, self.epoch2):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.duration, targ.duration)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'epocharrays'))
        self.assertTrue(hasattr(self.segment2, 'epocharrays'))

        self.assertEqual(len(self.segment1.epocharrays), 4)
        self.assertEqual(len(self.segment2.epocharrays), 2)

        for res, targ in zip(self.segment1.epocharrays, self.epocharr):
            assert_arrays_equal(res.times, targ.times)
            assert_arrays_equal(res.durations, targ.durations)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.epocharrays, self.epocharr2):
            assert_arrays_equal(res.times, targ.times)
            assert_arrays_equal(res.durations, targ.durations)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'events'))
        self.assertTrue(hasattr(self.segment2, 'events'))

        self.assertEqual(len(self.segment1.events), 4)
        self.assertEqual(len(self.segment2.events), 2)

        for res, targ in zip(self.segment1.events, self.event):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.events, self.event2):
            self.assertEqual(res.time, targ.time)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'eventarrays'))
        self.assertTrue(hasattr(self.segment2, 'eventarrays'))

        self.assertEqual(len(self.segment1.eventarrays), 4)
        self.assertEqual(len(self.segment2.eventarrays), 2)

        for res, targ in zip(self.segment1.eventarrays, self.eventarr):
            assert_arrays_equal(res.times, targ.times)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.eventarrays, self.eventarr2):
            assert_arrays_equal(res.times, targ.times)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'irregularlysampledsignals'))
        self.assertTrue(hasattr(self.segment2, 'irregularlysampledsignals'))

        self.assertEqual(len(self.segment1.irregularlysampledsignals), 4)
        self.assertEqual(len(self.segment2.irregularlysampledsignals), 2)

        for res, targ in zip(self.segment1.irregularlysampledsignals,
                             self.irsig):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.irregularlysampledsignals,
                             self.irsig2):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'spikes'))
        self.assertTrue(hasattr(self.segment2, 'spikes'))

        self.assertEqual(len(self.segment1.spikes), 4)
        self.assertEqual(len(self.segment2.spikes), 2)

        for res, targ in zip(self.segment1.spikes, self.spike):
            self.assertEqual(res, targ)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.spikes, self.spike2):
            self.assertEqual(res, targ)
            self.assertEqual(res.name, targ.name)

        self.assertTrue(hasattr(self.segment1, 'spiketrains'))
        self.assertTrue(hasattr(self.segment2, 'spiketrains'))

        self.assertEqual(len(self.segment1.spiketrains), 4)
        self.assertEqual(len(self.segment2.spiketrains), 2)

        for res, targ in zip(self.segment1.spiketrains, self.train):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

        for res, targ in zip(self.segment2.spiketrains, self.train2):
            assert_arrays_equal(res, targ)
            self.assertEqual(res.name, targ.name)

    def test_segment_all_data(self):
        # ``segment1.all_data`` is compared item by item with the segment1
        # fixture sequences, concatenated in the order written below.
        # ``zip`` stops at the shorter input, so only the common prefix of
        # the two sequences is actually checked.
        actual = self.segment1.all_data
        expected = (self.epoch1 + self.epocharr1 +
                    self.event1 + self.eventarr1 +
                    self.sig1 + self.sigarr1 +
                    self.irsig1 +
                    self.spike1 + self.train1)

        for got, want in zip(actual, expected):
            # Objects exposing a truthy ``ndim`` go through the array
            # comparison helper; everything else is compared with ``==``.
            array_like = hasattr(got, 'ndim') and got.ndim
            if not array_like:
                self.assertEqual(got, want)
            else:
                assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

    def test_segment_take_spikes_by_unit(self):
        # Call ``take_spikes_by_unit`` with no argument and with each unit
        # fixture wrapped in a one-element list, then compare the results
        # with the ``unit1spike`` / ``unit2spike`` fixtures.
        seg = self.segment1
        without_units = seg.take_spikes_by_unit()
        unit1_spikes = seg.take_spikes_by_unit([self.unit1])
        unit2_spikes = seg.take_spikes_by_unit([self.unit2])

        def check_spikes(found, wanted):
            # Pairwise comparison; ``zip`` ends at the shorter sequence.
            for got, want in zip(found, wanted):
                self.assertEqual(got, want)
                self.assertEqual(got.name, want.name)

        # Without a unit list the result is expected to equal ``[]``.
        self.assertEqual(without_units, [])

        check_spikes(unit1_spikes, self.unit1spike)
        check_spikes(unit2_spikes, self.unit2spike)

    def test_segment_take_spiketrains_by_unit(self):
        # Call ``take_spiketrains_by_unit`` with no argument and with each
        # unit fixture wrapped in a one-element list, then compare the
        # results with the ``unit1train`` / ``unit2train`` fixtures.
        seg = self.segment1
        without_units = seg.take_spiketrains_by_unit()
        unit1_trains = seg.take_spiketrains_by_unit([self.unit1])
        unit2_trains = seg.take_spiketrains_by_unit([self.unit2])

        def check_trains(found, wanted):
            # Pairwise array comparison; ``zip`` ends at the shorter input.
            for got, want in zip(found, wanted):
                assert_arrays_equal(got, want)
                self.assertEqual(got.name, want.name)

        # Without a unit list the result is expected to equal ``[]``.
        self.assertEqual(without_units, [])

        check_trains(unit1_trains, self.unit1train)
        check_trains(unit2_trains, self.unit2train)

    def test_segment_take_analogsignal_by_unit(self):
        # Select analog signals by unit: no argument, then each unit fixture
        # in a one-element list.  Results are compared with the ``chan1sig``
        # and ``chan2sig`` fixtures.
        take = self.segment1.take_analogsignal_by_unit
        empty_selection = take()
        unit1_signals = take([self.unit1])
        unit2_signals = take([self.unit2])

        # Without a unit list the result is expected to equal ``[]``.
        self.assertEqual(empty_selection, [])

        # ``zip`` ends at the shorter input, so only common items are checked.
        for got, want in zip(unit1_signals, self.chan1sig):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

        for got, want in zip(unit2_signals, self.chan2sig):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

    def test_segment_take_analogsignal_by_channelindex(self):
        # Select analog signals by channel index: no argument, then ``[1]``
        # and ``[2]``.  Results are compared with the ``chan1sig`` and
        # ``chan2sig`` fixtures.
        take = self.segment1.take_analogsignal_by_channelindex
        empty_selection = take()
        channel1_signals = take([1])
        channel2_signals = take([2])

        # Without an index list the result is expected to equal ``[]``.
        self.assertEqual(empty_selection, [])

        # ``zip`` ends at the shorter input, so only common items are checked.
        for got, want in zip(channel1_signals, self.chan1sig):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

        for got, want in zip(channel2_signals, self.chan2sig):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

    def test_segment_take_slice_of_analogsignalarray_by_unit(self):
        # Slice analog signal arrays by unit: no argument, then each unit
        # fixture in a one-element list.  Results are compared with the
        # ``chan1sigarr1`` and ``chan2sigarr1`` fixtures.
        seg = self.segment1
        first_unit = self.unit1
        second_unit = self.unit2
        take_slice = seg.take_slice_of_analogsignalarray_by_unit

        empty_selection = take_slice()
        first_slices = take_slice([first_unit])
        second_slices = take_slice([second_unit])

        # Without a unit list the result is expected to equal ``[]``.
        self.assertEqual(empty_selection, [])

        # ``zip`` ends at the shorter input, so only common items are checked.
        for got, want in zip(first_slices, self.chan1sigarr1):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

        for got, want in zip(second_slices, self.chan2sigarr1):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

    def test_segment_take_slice_of_analogsignalarray_by_channelindex(self):
        # Slice analog signal arrays by channel index: no argument, then
        # ``[1]`` and ``[2]``.  Results are compared with the
        # ``chan1sigarr1`` and ``chan2sigarr1`` fixtures.
        take_slice = (
            self.segment1.take_slice_of_analogsignalarray_by_channelindex)
        empty_selection = take_slice()
        channel1_slices = take_slice([1])
        channel2_slices = take_slice([2])

        # Without an index list the result is expected to equal ``[]``.
        self.assertEqual(empty_selection, [])

        # ``zip`` ends at the shorter input, so only common items are checked.
        for got, want in zip(channel1_slices, self.chan1sigarr1):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

        for got, want in zip(channel2_slices, self.chan2sigarr1):
            assert_arrays_equal(got, want)
            self.assertEqual(got.name, want.name)

    def test_segment_size(self):
        # ``segment1.size()`` is expected to map every child container name
        # listed below to a count of 2.
        reported = self.segment1.size()

        container_names = ("epochs", "events", "analogsignals",
                           "irregularlysampledsignals", "spikes",
                           "spiketrains", "epocharrays", "eventarrays",
                           "analogsignalarrays")
        expected = dict.fromkeys(container_names, 2)

        self.assertEqual(reported, expected)

    def test_segment_filter(self):
        # Run ``filter`` three ways: with no criteria, by a signal name, and
        # by the ``testattr`` keyword.
        seg = self.segment1
        unfiltered = seg.filter()
        by_name = seg.filter(name='analogsignal 1 1')
        by_attr = seg.filter(testattr=True)

        # No criteria: the result is expected to equal ``[]``.
        self.assertEqual(unfiltered, [])

        # Name lookup: exactly one hit, equal to the first ``sig1`` fixture.
        self.assertEqual(len(by_name), 1)
        assert_arrays_equal(by_name[0], self.sig1[0])

        # ``testattr=True``: two hits, the first ``epoch1`` fixture followed
        # by the first ``event1`` fixture.
        self.assertEqual(len(by_attr), 2)
        self.assertEqual(by_attr[0], self.epoch1[0])
        self.assertEqual(by_attr[1], self.event1[0])

    def test__children(self):
        # Attach segment1 to a freshly created block, then check the
        # relationship metadata (child/parent class names and container
        # names), the ordering of ``children`` and the single ``parents``
        # entry.
        data_objects = ('AnalogSignal', 'AnalogSignalArray',
                        'Epoch', 'EpochArray',
                        'Event', 'EventArray',
                        'IrregularlySampledSignal',
                        'Spike', 'SpikeTrain')
        data_containers = ('analogsignals', 'analogsignalarrays',
                           'epochs', 'epocharrays',
                           'events', 'eventarrays',
                           'irregularlysampledsignals',
                           'spikes', 'spiketrains')

        parent_block = Block(name='block1')
        parent_block.segments = [self.segment1]
        parent_block.create_many_to_one_relationship()

        seg = self.segment1

        # Class-name metadata.
        self.assertEqual(seg._container_child_objects, ())
        self.assertEqual(seg._data_child_objects, data_objects)
        self.assertEqual(seg._single_parent_objects, ('Block',))
        self.assertEqual(seg._multi_child_objects, ())
        self.assertEqual(seg._multi_parent_objects, ())
        self.assertEqual(seg._child_properties, ())

        self.assertEqual(seg._single_child_objects, data_objects)

        # Container-name metadata.
        self.assertEqual(seg._container_child_containers, ())
        self.assertEqual(seg._data_child_containers, data_containers)
        self.assertEqual(seg._single_child_containers, data_containers)
        self.assertEqual(seg._single_parent_containers, ('block',))
        self.assertEqual(seg._multi_child_containers, ())
        self.assertEqual(seg._multi_parent_containers, ())

        # Combined views.
        self.assertEqual(seg._child_objects, data_objects)
        self.assertEqual(seg._child_containers, data_containers)
        self.assertEqual(seg._parent_objects, ('Block',))
        self.assertEqual(seg._parent_containers, ('block',))

        # ``children`` must hold as many items as all segment1 fixture
        # sequences together.
        self.assertEqual(len(seg.children),
                         sum(len(group) for group in (self.sig1,
                                                      self.sigarr1,
                                                      self.epoch1,
                                                      self.epocharr1,
                                                      self.event1,
                                                      self.eventarr1,
                                                      self.irsig1,
                                                      self.spike1,
                                                      self.train1)))

        # Two children per type are checked, in this type order.  The name
        # fixtures are looked up lazily, one assertion at a time.
        name_attrs = ('signames1', 'sigarrnames1',
                      'epochnames1', 'epocharrnames1',
                      'eventnames1', 'eventarrnames1',
                      'irsignames1',
                      'spikenames1', 'trainnames1')
        for group, attr in enumerate(name_attrs):
            for offset in (0, 1):
                child = seg.children[2 * group + offset]
                self.assertEqual(child.name, getattr(self, attr)[offset])

        self.assertEqual(len(seg.parents), 1)
        self.assertEqual(seg.parents[0].name, 'block1')

        # Re-running the relationship builders on the segment itself must
        # leave it compliant.
        seg.create_many_to_one_relationship()
        seg.create_many_to_many_relationship()
        seg.create_relationship()
        assert_neo_object_is_compliant(seg)