def test__recordingchannelgroup__cascade(self):
        """A cascading fake RecordingChannelGroup has exactly one child in
        every container (recursively), each carrying the annotations held
        in ``self.annotations``."""
        rcg = fake_neo(obj_type='RecordingChannelGroup', cascade=True)

        self.assertTrue(isinstance(rcg, RecordingChannelGroup))
        assert_neo_object_is_compliant(rcg)
        self.assertEqual(rcg.annotations, self.annotations)

        def check_single_children(parent, *containers):
            # all length checks first, then the annotation of each element
            for container in containers:
                self.assertEqual(len(getattr(parent, container)), 1)
            for container in containers:
                self.assertEqual(getattr(parent, container)[0].annotations,
                                 self.annotations)

        check_single_children(rcg, 'recordingchannels')
        check_single_children(rcg, 'units')
        check_single_children(rcg, 'analogsignalarrays')

        rchan = rcg.recordingchannels[0]
        check_single_children(rchan, 'analogsignals',
                              'irregularlysampledsignals')

        unit = rcg.units[0]
        check_single_children(unit, 'spiketrains', 'spikes')
    def test__analogsignalarray__cascade(self):
        """A cascading fake AnalogSignalArray is compliant and carries the
        annotations held in ``self.annotations``."""
        signal = fake_neo(obj_type='AnalogSignalArray', cascade=True)

        self.assertTrue(isinstance(signal, AnalogSignalArray))
        assert_neo_object_is_compliant(signal)
        self.assertEqual(signal.annotations, self.annotations)
    def test__event__cascade(self):
        """A cascading fake Event is compliant and carries the annotations
        held in ``self.annotations``."""
        event = fake_neo(obj_type='Event', cascade=True)

        self.assertTrue(isinstance(event, Event))
        assert_neo_object_is_compliant(event)
        self.assertEqual(event.annotations, self.annotations)
    def test__spike__nocascade(self):
        """A non-cascading fake Spike is compliant and carries the
        annotations held in ``self.annotations``."""
        spike = fake_neo(obj_type='Spike', cascade=False)

        self.assertTrue(isinstance(spike, Spike))
        assert_neo_object_is_compliant(spike)
        self.assertEqual(spike.annotations, self.annotations)
    def test__epocharray__nocascade(self):
        """A non-cascading fake EpochArray is compliant and carries the
        annotations held in ``self.annotations``."""
        epochs = fake_neo(obj_type='EpochArray', cascade=False)

        self.assertTrue(isinstance(epochs, EpochArray))
        assert_neo_object_is_compliant(epochs)
        self.assertEqual(epochs.annotations, self.annotations)
    def test__irregularlysampledsignal__nocascade(self):
        """A non-cascading fake IrregularlySampledSignal is compliant and
        carries the annotations held in ``self.annotations``."""
        signal = fake_neo(obj_type='IrregularlySampledSignal', cascade=False)

        self.assertTrue(isinstance(signal, IrregularlySampledSignal))
        assert_neo_object_is_compliant(signal)
        self.assertEqual(signal.annotations, self.annotations)
예제 #7
0
 def test_property_change(self):
     """
     Save a childless fake of every object type and check that the replica
     read back is equivalent, i.e. quantities, units, types etc. survive.
     """
     iom = NeoHdf5IO(filename=self.test_file)
     for name in objectnames:
         original = fake_neo(name, cascade=False)
         iom.save(original)
         # hdf5_path is expected to be set by save()
         self.assertTrue(hasattr(original, "hdf5_path"))
         assert_objects_equivalent(
             original, iom.get(original.hdf5_path, cascade=False))
    def test__unit__nocascade(self):
        """A non-cascading fake Unit is compliant, annotated, and has no
        spiketrains or spikes."""
        unit = fake_neo(obj_type='Unit', cascade=False)

        self.assertTrue(isinstance(unit, Unit))
        assert_neo_object_is_compliant(unit)
        self.assertEqual(unit.annotations, self.annotations)

        for container in ('spiketrains', 'spikes'):
            self.assertEqual(len(getattr(unit, container)), 0)
    def test__recordingchannel__nocascade(self):
        """A non-cascading fake RecordingChannel is compliant, annotated,
        and has no signals of either kind."""
        rchan = fake_neo(obj_type='RecordingChannel', cascade=False)

        self.assertTrue(isinstance(rchan, RecordingChannel))
        assert_neo_object_is_compliant(rchan)
        self.assertEqual(rchan.annotations, self.annotations)

        for container in ('analogsignals', 'irregularlysampledsignals'):
            self.assertEqual(len(getattr(rchan, container)), 0)
    def test__block__nocascade(self):
        """A non-cascading fake Block is compliant, annotated, and has no
        segments or recording channel groups."""
        block = fake_neo(obj_type='Block', cascade=False)

        self.assertTrue(isinstance(block, Block))
        assert_neo_object_is_compliant(block)
        self.assertEqual(block.annotations, self.annotations)

        for container in ('segments', 'recordingchannelgroups'):
            self.assertEqual(len(getattr(block, container)), 0)
    def test__recordingchannelgroup__nocascade(self):
        """A non-cascading fake RecordingChannelGroup is compliant,
        annotated, and all of its child containers are empty."""
        rcg = fake_neo(obj_type='RecordingChannelGroup', cascade=False)

        self.assertTrue(isinstance(rcg, RecordingChannelGroup))
        assert_neo_object_is_compliant(rcg)
        self.assertEqual(rcg.annotations, self.annotations)

        for container in ('recordingchannels', 'units', 'analogsignalarrays'):
            self.assertEqual(len(getattr(rcg, container)), 0)
예제 #12
0
 def test_attr_changes(self):
     """
     Save a childless fake of every object type, read it back, overwrite
     its attributes with new fake values, save it again and check that a
     second read returns an equivalent object.
     """
     iom = NeoHdf5IO(filename=self.test_file)
     for obj_type in objectnames:
         obj = fake_neo(obj_type=obj_type, cascade=False)
         iom.save(obj)
         orig_obj = iom.get(obj.hdf5_path)
         for attr in obj._all_attrs:
             # Change the loaded copy: it is the object that is saved and
             # compared below, so the new values must be set on it for the
             # save/load round trip of the changes to be exercised at all.
             if hasattr(orig_obj, attr[0]):
                 setattr(orig_obj, attr[0], get_fake_value(*attr))
         iom.save(orig_obj)
         test_obj = iom.get(orig_obj.hdf5_path)
         assert_objects_equivalent(orig_obj, test_obj)
예제 #13
0
 def test_create(self):
     """
     Save a default fake neo structure (signals, segments, blocks etc.),
     reconnect to the file, read it back and check the copy is compliant
     and has the same sub-schema as the original.
     """
     iom = NeoHdf5IO(filename=self.test_file)
     written = fake_neo()
     iom.save(written)
     # hdf5_path must have been assigned by save()
     self.assertTrue(hasattr(written, "hdf5_path"))

     # close and reconnect before reading the object back
     iom.close()
     iom.connect(filename=self.test_file)
     loaded = iom.get(written.hdf5_path)

     assert_neo_object_is_compliant(loaded)
     assert_same_sub_schema(written, loaded)
예제 #14
0
 def test_attr_changes(self):
     """
     Save a childless fake of every object type, read it back, overwrite
     its necessary and recommended attributes with new fake values, save
     it again and check that a second read returns an equivalent object.
     """
     iom = NeoHdf5IO(filename=self.test_file)
     for obj_type in class_by_name:
         obj = fake_neo(obj_type=obj_type, cascade=False)
         iom.save(obj)
         orig_obj = iom.get(obj.hdf5_path)
         attrs = (classes_necessary_attributes[obj_type] +
                  classes_recommended_attributes[obj_type])
         for attr in attrs:
             # Change the loaded copy: it is the object that is saved and
             # compared below, so the new values must be set on it for the
             # save/load round trip of the changes to be exercised at all.
             if hasattr(orig_obj, attr[0]):
                 setattr(orig_obj, attr[0], get_fake_value(*attr))
         iom.save(orig_obj)
         test_obj = iom.get(orig_obj.hdf5_path)
         assert_objects_equivalent(orig_obj, test_obj)
    def test__segment__nocascade(self):
        """A non-cascading fake Segment is compliant, annotated, and every
        one of its child containers is empty."""
        seg = fake_neo(obj_type='Segment', cascade=False)

        self.assertTrue(isinstance(seg, Segment))
        assert_neo_object_is_compliant(seg)
        self.assertEqual(seg.annotations, self.annotations)

        empty_containers = (
            'analogsignalarrays',
            'analogsignals',
            'irregularlysampledsignals',
            'spiketrains',
            'spikes',
            'events',
            'epochs',
            'eventarrays',
            'epocharrays',
        )
        for container in empty_containers:
            self.assertEqual(len(getattr(seg, container)), 0)
예제 #16
0
 def test_relations(self):
     """
     make sure the change in relationships is saved properly in the file,
     including correct M2M, no redundancy etc. RC -> RCG not tested.

     Every object type is saved with its children and read back; the
     replica must match the original recursively through all child
     containers.
     """
     def assert_children(obj, replica):
         # Compare the string forms of both objects, then recurse pairwise
         # into every child container listed in ``_child_containers``.
         # (A plain string comparison replaces hashing with md5, which
         # needs bytes on Python 3 and hides the difference on failure.)
         self.assertEqual(str(obj), str(replica))
         for container in getattr(obj, '_child_containers', []):
             ch1 = getattr(obj, container)
             ch2 = getattr(replica, container)
             self.assertEqual(len(ch1), len(ch2))
             for child, child_replica in zip(ch1, ch2):
                 assert_children(child, child_replica)

     iom = NeoHdf5IO(filename=self.test_file)
     for obj_type in objectnames:
         obj = fake_neo(obj_type, cascade=True)
         iom.save(obj)
         self.assertTrue(hasattr(obj, "hdf5_path"))
         replica = iom.get(obj.hdf5_path, cascade=True)
         assert_children(obj, replica)
예제 #17
0
def generate_diagram(filename, rect_pos, rect_width, figsize):
    """Draw a box diagram of neo classes and their relationships.

    For every class name in ``rect_pos`` a fake instance is built with
    ``fake_neo`` so its relationship and attribute declarations can be
    inspected. Connections are drawn from each child box to its parent box;
    then each box is drawn with a green title row, coloured bands for the
    single/multi relationships and the necessary attributes, and one text
    row per relationship and per attribute.

    :param filename: output path passed to ``fig.savefig``.
    :param rect_pos: mapping ``{class_name: (x, y)}`` giving the lower-left
        corner of each box. Every child class referenced by a drawn class
        must also be a key, since its position is looked up here.
    :param rect_width: width of every box, in axis units.
    :param figsize: figure size; also used as the x/y axis limits.

    Reads the module-level settings ``line_heigth``, ``fontsize``,
    ``left_text_shift`` and ``dpi``. The figure is not closed here.
    """

    def type_label(attr):
        # Text describing an attribute's declared type. ``attr`` is an entry
        # of ``obj._all_attrs`` laid out as (name, type, ...); for Quantity
        # and ndarray entries attr[2] is printed as the number of dimensions,
        # and for ndarray attr[3] must expose ``.kind`` (presumably a numpy
        # dtype).
        attrtype = attr[1]
        if attrtype == pq.Quantity:
            if attr[2] == 0:
                return 'Quantity scalar'
            return 'Quantity %dD' % attr[2]
        if attrtype == np.ndarray:
            return "np.ndarray %dD dt='%s'" % (attr[2], attr[3].kind)
        if attrtype == datetime:
            return 'datetime'
        return attrtype.__name__

    fig = pyplot.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])

    objs = {}
    all_h = {}
    for name in rect_pos:
        objs[name] = fake_neo(name)
        all_h[name] = get_rect_height(name, objs[name])

    # draw connections; one entry per relationship kind:
    # (attribute listing the child names, colour, alpha, attach to right edge)
    connection_kinds = (
        ('_single_child_objects', 'c', 1., True),
        ('_multi_child_objects', 'm', 1., False),
        ('_child_properties', 'y', 0.3, True),
    )
    for name, pos in rect_pos.items():
        obj = objs[name]
        for rel_attr, color, alpha, right_edge in connection_kinds:
            for ch_name in getattr(obj, rel_attr, []):
                x1, y1 = calc_coordinates(rect_pos[ch_name], all_h[ch_name])
                x2, y2 = calc_coordinates(pos, all_h[name])

                if right_edge:
                    x2 += rect_width
                    connectionstyle = "arc3,rad=-0.2"
                elif y2 >= y1:
                    # multi-child links curve one way or the other depending
                    # on whether the parent sits above or below the child
                    connectionstyle = "arc3,rad=0.7"
                else:
                    connectionstyle = "arc3,rad=-0.7"

                annotate(ax=ax, coord1=(x1, y1), coord2=(x2, y2),
                         connectionstyle=connectionstyle,
                         color=color, alpha=alpha)

    # draw boxes
    for name, pos in rect_pos.items():
        htotal = all_h[name]
        obj = objs[name]
        top = pos[1] + htotal
        allrelationship = (getattr(obj, '_child_containers', []) +
                           getattr(obj, '_multi_parent_containers', []))
        n_necessary = len(obj._necessary_attrs)

        # outline
        ax.add_patch(Rectangle(pos, rect_width, htotal,
                               facecolor='w', edgecolor='k', linewidth=2.))

        # title band (green), 1.5 rows high
        ax.add_patch(Rectangle((pos[0], top - line_heigth*1.5),
                               rect_width, line_heigth*1.5,
                               facecolor='g', edgecolor='k', alpha=.5,
                               linewidth=2.))

        # single-relationship band (cyan), directly under the title
        single = getattr(obj, '_single_child_objects', [])
        single_height = len(single)*line_heigth
        ax.add_patch(Rectangle((pos[0], top - line_heigth*(1.5+len(single))),
                               rect_width, single_height,
                               facecolor='c', edgecolor='k', alpha=.5))

        # multi-relationship band (magenta), stacked below the single band
        multi = (getattr(obj, '_multi_child_objects', []) +
                 getattr(obj, '_multi_parent_containers', []))
        ax.add_patch(Rectangle(
            (pos[0], top - line_heigth*(1.5+len(multi)) - single_height),
            rect_width, len(multi)*line_heigth,
            facecolor='m', edgecolor='k', alpha=.5))

        # necessary-attributes band (red), below all relationship rows
        ax.add_patch(Rectangle(
            (pos[0], top - line_heigth*(1.5+len(allrelationship) +
                                        n_necessary)),
            rect_width, line_heigth*n_necessary,
            facecolor='r', edgecolor='k', alpha=.5))

        # class name; '* ' marks classes that declare a _quantity_attr
        post = '* ' if hasattr(obj, '_quantity_attr') else ''
        ax.text(pos[0]+rect_width/2., top - line_heigth*1.5/2.,
                name+post,
                horizontalalignment='center', verticalalignment='center',
                fontsize=fontsize+2,
                fontproperties=FontProperties(weight='bold'),
                )

        # one row per relationship
        for i, relat in enumerate(allrelationship):
            ax.text(pos[0]+left_text_shift, top - line_heigth*(i+2),
                    relat+': list',
                    horizontalalignment='left', verticalalignment='center',
                    fontsize=fontsize,
                    )

        # one row per attribute, below the relationship rows
        for i, attr in enumerate(obj._all_attrs):
            attrname = attr[0]
            if (hasattr(obj, '_quantity_attr') and
                    obj._quantity_attr == attrname):
                t1 = attrname+'(object itself)'
            else:
                t1 = attrname

            ax.text(pos[0]+left_text_shift,
                    top - line_heigth*(i+len(allrelationship)+2),
                    t1+' :  '+type_label(attr),
                    horizontalalignment='left', verticalalignment='center',
                    fontsize=fontsize,
                    )

    xlim, ylim = figsize
    ax.set_xlim(0, xlim)
    ax.set_ylim(0, ylim)

    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(filename, dpi=dpi)
예제 #18
0
def generate_diagram(filename, rect_pos, rect_width, figsize):
    """Draw a box diagram of neo classes and their relationships.

    For every class name in ``rect_pos`` a fake instance is built with
    ``fake_neo`` so its relationship and attribute declarations can be
    inspected. Connections are drawn from each child box to its parent box;
    then each box is drawn with a green title row, coloured bands for the
    single/multi relationships and the necessary attributes, and one text
    row per relationship and per attribute.

    :param filename: output path passed to ``fig.savefig``.
    :param rect_pos: mapping ``{class_name: (x, y)}`` giving the lower-left
        corner of each box. Every child class referenced by a drawn class
        must also be a key, since its position is looked up here.
    :param rect_width: width of every box, in axis units.
    :param figsize: figure size; also used as the x/y axis limits.

    Reads the module-level settings ``line_heigth``, ``fontsize``,
    ``left_text_shift`` and ``dpi``. The figure is not closed here.
    """

    def type_label(attr):
        # Text describing an attribute's declared type. ``attr`` is an entry
        # of ``obj._all_attrs`` laid out as (name, type, ...); for Quantity
        # and ndarray entries attr[2] is printed as the number of dimensions,
        # and for ndarray attr[3] must expose ``.kind`` (presumably a numpy
        # dtype).
        attrtype = attr[1]
        if attrtype == pq.Quantity:
            if attr[2] == 0:
                return 'Quantity scalar'
            return 'Quantity %dD' % attr[2]
        if attrtype == np.ndarray:
            return "np.ndarray %dD dt='%s'" % (attr[2], attr[3].kind)
        if attrtype == datetime:
            return 'datetime'
        return attrtype.__name__

    fig = pyplot.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])

    objs = {}
    all_h = {}
    for name in rect_pos:
        objs[name] = fake_neo(name)
        all_h[name] = get_rect_height(name, objs[name])

    # draw connections; one entry per relationship kind:
    # (attribute listing the child names, colour, alpha, attach to right edge)
    connection_kinds = (
        ('_single_child_objects', 'c', 1., True),
        ('_multi_child_objects', 'm', 1., False),
        ('_child_properties', 'y', 0.3, True),
    )
    for name, pos in rect_pos.items():
        obj = objs[name]
        for rel_attr, color, alpha, right_edge in connection_kinds:
            for ch_name in getattr(obj, rel_attr, []):
                x1, y1 = calc_coordinates(rect_pos[ch_name], all_h[ch_name])
                x2, y2 = calc_coordinates(pos, all_h[name])

                if right_edge:
                    x2 += rect_width
                    connectionstyle = "arc3,rad=-0.2"
                elif y2 >= y1:
                    # multi-child links curve one way or the other depending
                    # on whether the parent sits above or below the child
                    connectionstyle = "arc3,rad=0.7"
                else:
                    connectionstyle = "arc3,rad=-0.7"

                annotate(ax=ax,
                         coord1=(x1, y1),
                         coord2=(x2, y2),
                         connectionstyle=connectionstyle,
                         color=color,
                         alpha=alpha)

    # draw boxes
    for name, pos in rect_pos.items():
        htotal = all_h[name]
        obj = objs[name]
        top = pos[1] + htotal
        allrelationship = (getattr(obj, '_child_containers', []) +
                           getattr(obj, '_multi_parent_containers', []))
        n_necessary = len(obj._necessary_attrs)

        # outline
        ax.add_patch(Rectangle(pos,
                               rect_width,
                               htotal,
                               facecolor='w',
                               edgecolor='k',
                               linewidth=2.))

        # title band (green), 1.5 rows high
        ax.add_patch(Rectangle((pos[0], top - line_heigth * 1.5),
                               rect_width,
                               line_heigth * 1.5,
                               facecolor='g',
                               edgecolor='k',
                               alpha=.5,
                               linewidth=2.))

        # single-relationship band (cyan), directly under the title
        single = getattr(obj, '_single_child_objects', [])
        single_height = len(single) * line_heigth
        ax.add_patch(Rectangle((pos[0],
                                top - line_heigth * (1.5 + len(single))),
                               rect_width,
                               single_height,
                               facecolor='c',
                               edgecolor='k',
                               alpha=.5))

        # multi-relationship band (magenta), stacked below the single band
        multi = (getattr(obj, '_multi_child_objects', []) +
                 getattr(obj, '_multi_parent_containers', []))
        ax.add_patch(Rectangle((pos[0],
                                top - line_heigth * (1.5 + len(multi)) -
                                single_height),
                               rect_width,
                               len(multi) * line_heigth,
                               facecolor='m',
                               edgecolor='k',
                               alpha=.5))

        # necessary-attributes band (red), below all relationship rows
        ax.add_patch(Rectangle((pos[0],
                                top - line_heigth *
                                (1.5 + len(allrelationship) + n_necessary)),
                               rect_width,
                               line_heigth * n_necessary,
                               facecolor='r',
                               edgecolor='k',
                               alpha=.5))

        # class name; '* ' marks classes that declare a _quantity_attr
        post = '* ' if hasattr(obj, '_quantity_attr') else ''
        ax.text(
            pos[0] + rect_width / 2.,
            top - line_heigth * 1.5 / 2.,
            name + post,
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=fontsize + 2,
            fontproperties=FontProperties(weight='bold'),
        )

        # one row per relationship
        for i, relat in enumerate(allrelationship):
            ax.text(
                pos[0] + left_text_shift,
                top - line_heigth * (i + 2),
                relat + ': list',
                horizontalalignment='left',
                verticalalignment='center',
                fontsize=fontsize,
            )

        # one row per attribute, below the relationship rows
        for i, attr in enumerate(obj._all_attrs):
            attrname = attr[0]
            if (hasattr(obj, '_quantity_attr')
                    and obj._quantity_attr == attrname):
                t1 = attrname + '(object itself)'
            else:
                t1 = attrname

            ax.text(
                pos[0] + left_text_shift,
                top - line_heigth * (i + len(allrelationship) + 2),
                t1 + ' :  ' + type_label(attr),
                horizontalalignment='left',
                verticalalignment='center',
                fontsize=fontsize,
            )

    xlim, ylim = figsize
    ax.set_xlim(0, xlim)
    ax.set_ylim(0, ylim)

    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(filename, dpi=dpi)
    def test__block__cascade(self):
        """A cascading fake Block has exactly one child in every container
        (recursively), each carrying the annotations held in
        ``self.annotations``."""
        block = fake_neo(obj_type='Block', cascade=True)

        self.assertTrue(isinstance(block, Block))
        assert_neo_object_is_compliant(block)
        self.assertEqual(block.annotations, self.annotations)

        def check_single_children(parent, *containers):
            # all length checks first, then the annotation of each element
            for container in containers:
                self.assertEqual(len(getattr(parent, container)), 1)
            for container in containers:
                self.assertEqual(getattr(parent, container)[0].annotations,
                                 self.annotations)

        check_single_children(block, 'segments')
        check_single_children(block, 'recordingchannelgroups')
        seg = block.segments[0]
        rcg = block.recordingchannelgroups[0]

        check_single_children(seg,
                              'analogsignalarrays',
                              'analogsignals',
                              'irregularlysampledsignals',
                              'spiketrains',
                              'spikes',
                              'events',
                              'epochs',
                              'eventarrays',
                              'epocharrays')

        check_single_children(rcg, 'recordingchannels')
        check_single_children(rcg, 'units')
        check_single_children(rcg, 'analogsignalarrays')

        rchan = rcg.recordingchannels[0]
        check_single_children(rchan, 'analogsignals',
                              'irregularlysampledsignals')

        unit = rcg.units[0]
        check_single_children(unit, 'spiketrains', 'spikes')