def test__recordingchannelgroup__cascade(self):
    # A cascaded fake RecordingChannelGroup must come with one annotated
    # object in each child container, and one in each grandchild container.
    group = fake_neo(obj_type='RecordingChannelGroup', cascade=True)

    self.assertTrue(isinstance(group, RecordingChannelGroup))
    assert_neo_object_is_compliant(group)
    self.assertEqual(group.annotations, self.annotations)

    # direct children
    self.assertEqual(len(group.recordingchannels), 1)
    channel = group.recordingchannels[0]
    self.assertEqual(channel.annotations, self.annotations)

    self.assertEqual(len(group.units), 1)
    neuron = group.units[0]
    self.assertEqual(neuron.annotations, self.annotations)

    self.assertEqual(len(group.analogsignalarrays), 1)
    self.assertEqual(group.analogsignalarrays[0].annotations,
                     self.annotations)

    # grandchildren: for each parent, sizes are checked before annotations
    grandchildren = (
        (channel, ('analogsignals', 'irregularlysampledsignals')),
        (neuron, ('spiketrains', 'spikes')),
    )
    for parent, containers in grandchildren:
        for container in containers:
            self.assertEqual(len(getattr(parent, container)), 1)
        for container in containers:
            self.assertEqual(getattr(parent, container)[0].annotations,
                             self.annotations)
def test__analogsignalarray__cascade(self):
    # Generate a fake AnalogSignalArray with cascade switched on and check
    # its class, its compliance and its annotations.
    generated = fake_neo(obj_type='AnalogSignalArray', cascade=True)

    self.assertTrue(isinstance(generated, AnalogSignalArray))
    assert_neo_object_is_compliant(generated)
    self.assertEqual(generated.annotations, self.annotations)
def test__event__cascade(self):
    # Generate a fake Event with cascade switched on and check its class,
    # its compliance and its annotations.
    generated = fake_neo(obj_type='Event', cascade=True)

    self.assertTrue(isinstance(generated, Event))
    assert_neo_object_is_compliant(generated)
    self.assertEqual(generated.annotations, self.annotations)
def test__spike__nocascade(self):
    # Generate a fake Spike with cascade switched off and check its class,
    # its compliance and its annotations.
    generated = fake_neo(obj_type='Spike', cascade=False)

    self.assertTrue(isinstance(generated, Spike))
    assert_neo_object_is_compliant(generated)
    self.assertEqual(generated.annotations, self.annotations)
def test__epocharray__nocascade(self):
    # Generate a fake EpochArray with cascade switched off and check its
    # class, its compliance and its annotations.
    generated = fake_neo(obj_type='EpochArray', cascade=False)

    self.assertTrue(isinstance(generated, EpochArray))
    assert_neo_object_is_compliant(generated)
    self.assertEqual(generated.annotations, self.annotations)
def test__irregularlysampledsignal__nocascade(self):
    # Generate a fake IrregularlySampledSignal with cascade switched off and
    # check its class, its compliance and its annotations.
    generated = fake_neo(obj_type='IrregularlySampledSignal', cascade=False)

    self.assertTrue(isinstance(generated, IrregularlySampledSignal))
    assert_neo_object_is_compliant(generated)
    self.assertEqual(generated.annotations, self.annotations)
def test_property_change(self):
    """ Make sure all attributes are saved properly after the change,
    including quantities, units, types etc.

    Every object type listed in ``objectnames`` is generated without
    cascading, saved, read back from its ``hdf5_path`` and compared with
    the object that was saved.
    """
    iom = NeoHdf5IO(filename=self.test_file)
    try:
        for obj_type in objectnames:
            obj = fake_neo(obj_type, cascade=False)
            iom.save(obj)
            # save() is expected to tag the object with its location
            self.assertTrue(hasattr(obj, "hdf5_path"))
            replica = iom.get(obj.hdf5_path, cascade=False)
            assert_objects_equivalent(obj, replica)
    finally:
        # Release the file handle even when an assertion fails part-way,
        # so the test file is not left open.
        iom.close()
def test__unit__nocascade(self):
    # A Unit generated without cascading must pass the compliance check and
    # have empty child containers.
    neuron = fake_neo(obj_type='Unit', cascade=False)

    self.assertTrue(isinstance(neuron, Unit))
    assert_neo_object_is_compliant(neuron)
    self.assertEqual(neuron.annotations, self.annotations)

    for container in ('spiketrains', 'spikes'):
        self.assertEqual(len(getattr(neuron, container)), 0)
def test__recordingchannel__nocascade(self):
    # A RecordingChannel generated without cascading must pass the
    # compliance check and have empty child containers.
    channel = fake_neo(obj_type='RecordingChannel', cascade=False)

    self.assertTrue(isinstance(channel, RecordingChannel))
    assert_neo_object_is_compliant(channel)
    self.assertEqual(channel.annotations, self.annotations)

    for container in ('analogsignals', 'irregularlysampledsignals'):
        self.assertEqual(len(getattr(channel, container)), 0)
def test__block__nocascade(self):
    # A Block generated without cascading must pass the compliance check and
    # have empty child containers.
    block = fake_neo(obj_type='Block', cascade=False)

    self.assertTrue(isinstance(block, Block))
    assert_neo_object_is_compliant(block)
    self.assertEqual(block.annotations, self.annotations)

    for container in ('segments', 'recordingchannelgroups'):
        self.assertEqual(len(getattr(block, container)), 0)
def test__recordingchannelgroup__nocascade(self):
    # A RecordingChannelGroup generated without cascading must pass the
    # compliance check and have empty child containers.
    group = fake_neo(obj_type='RecordingChannelGroup', cascade=False)

    self.assertTrue(isinstance(group, RecordingChannelGroup))
    assert_neo_object_is_compliant(group)
    self.assertEqual(group.annotations, self.annotations)

    for container in ('recordingchannels', 'units', 'analogsignalarrays'):
        self.assertEqual(len(getattr(group, container)), 0)
def test_attr_changes(self):
    """ gets an object, changes its attributes, saves it, then compares how
    good the changes were saved.

    For every object type in ``objectnames``: a fake object is saved, read
    back, given fresh fake values for each attribute it exposes, saved
    again, and finally compared with what the file returns.
    """
    iom = NeoHdf5IO(filename=self.test_file)
    try:
        for obj_type in objectnames:
            obj = fake_neo(obj_type=obj_type, cascade=False)
            iom.save(obj)
            orig_obj = iom.get(obj.hdf5_path)
            for attr in obj._all_attrs:
                if hasattr(orig_obj, attr[0]):
                    # The new value must land on the object that is saved
                    # and compared below; setting it on ``obj`` would leave
                    # ``orig_obj`` untouched and the test would check
                    # nothing.
                    setattr(orig_obj, attr[0], get_fake_value(*attr))
            iom.save(orig_obj)
            test_obj = iom.get(orig_obj.hdf5_path)
            assert_objects_equivalent(orig_obj, test_obj)
    finally:
        # Release the file handle even when an assertion fails part-way.
        iom.close()
def test_create(self):
    """
    Create test file with signals, segments, blocks etc.

    The structure is saved, the connection is closed and reopened, and the
    object read back is checked for compliance and compared with the
    original.
    """
    iom = NeoHdf5IO(filename=self.test_file)
    try:
        b1 = fake_neo()  # creating a structure
        iom.save(b1)  # saving
        # must be assigned after save
        self.assertTrue(hasattr(b1, "hdf5_path"))
        # reopen the file so the object is read through a fresh connection
        iom.close()
        iom.connect(filename=self.test_file)
        b2 = iom.get(b1.hdf5_path)  # new object
        assert_neo_object_is_compliant(b2)
        assert_same_sub_schema(b1, b2)
    finally:
        # The reopened connection was never closed; release it even when an
        # assertion above fails.
        iom.close()
def test_attr_changes(self):
    """ gets an object, changes its attributes, saves it, then compares how
    good the changes were saved.

    For every class name in ``class_by_name``: a fake object is saved, read
    back, given fresh fake values for each necessary and recommended
    attribute it exposes, saved again, and finally compared with what the
    file returns.
    """
    iom = NeoHdf5IO(filename=self.test_file)
    try:
        for obj_type in class_by_name:
            obj = fake_neo(obj_type=obj_type, cascade=False)
            iom.save(obj)
            orig_obj = iom.get(obj.hdf5_path)
            attrs = (classes_necessary_attributes[obj_type] +
                     classes_recommended_attributes[obj_type])
            for attr in attrs:
                if hasattr(orig_obj, attr[0]):
                    # The new value must land on the object that is saved
                    # and compared below; setting it on ``obj`` would leave
                    # ``orig_obj`` untouched and the test would check
                    # nothing.
                    setattr(orig_obj, attr[0], get_fake_value(*attr))
            iom.save(orig_obj)
            test_obj = iom.get(orig_obj.hdf5_path)
            assert_objects_equivalent(orig_obj, test_obj)
    finally:
        # Release the file handle even when an assertion fails part-way.
        iom.close()
def test__segment__nocascade(self):
    # A Segment generated without cascading must pass the compliance check
    # and have empty child containers.
    segment = fake_neo(obj_type='Segment', cascade=False)

    self.assertTrue(isinstance(segment, Segment))
    assert_neo_object_is_compliant(segment)
    self.assertEqual(segment.annotations, self.annotations)

    empty_containers = (
        'analogsignalarrays',
        'analogsignals',
        'irregularlysampledsignals',
        'spiketrains',
        'spikes',
        'events',
        'epochs',
        'eventarrays',
        'epocharrays',
    )
    for container in empty_containers:
        self.assertEqual(len(getattr(segment, container)), 0)
def test_relations(self):
    """
    make sure the change in relationships is saved properly in the file,
    including correct M2M, no redundancy etc. RC -> RCG not tested.

    Every object type in ``objectnames`` is generated with cascading, saved
    and read back; the saved object and its replica are then compared
    recursively through their child containers.
    """
    def assert_children(obj, replica):
        # Comparing the string forms directly is equivalent to comparing
        # their MD5 digests, and avoids hashing a text string (which raises
        # TypeError on Python 3).
        self.assertEqual(str(obj), str(replica))
        for container in getattr(obj, '_child_containers', []):
            ch1 = getattr(obj, container)
            ch2 = getattr(replica, container)
            self.assertEqual(len(ch1), len(ch2))
            for child, child_replica in zip(ch1, ch2):
                assert_children(child, child_replica)

    iom = NeoHdf5IO(filename=self.test_file)
    try:
        for obj_type in objectnames:
            obj = fake_neo(obj_type, cascade=True)
            iom.save(obj)
            self.assertTrue(hasattr(obj, "hdf5_path"))
            replica = iom.get(obj.hdf5_path, cascade=True)
            # The helper is a local closure, not a method of the test case,
            # so it is called directly rather than through ``self``.
            assert_children(obj, replica)
    finally:
        # Release the file handle even when an assertion fails part-way.
        iom.close()
def generate_diagram(filename, rect_pos, rect_width, figsize):
    """
    Draw one box per class named in ``rect_pos`` and save the figure.

    Each box lists the relationship containers and attributes of a fake
    instance of the class; arrows connect a class to its child classes.

    filename: destination handed to ``savefig``.
    rect_pos: mapping of class name -> (x, y) position given to its box.
    rect_width: width shared by every box.
    figsize: (width, height), used as the figure size and as the axis limits.
    """
    figure = pyplot.figure(figsize=figsize)
    axes = figure.add_axes([0, 0, 1, 1])

    # One fake instance per class, plus the height of the box showing it.
    heights = {}
    instances = {}
    for name in rect_pos:
        instances[name] = fake_neo(name)
        heights[name] = get_rect_height(name, instances[name])

    # draw connections
    link_colors = ('c', 'm', 'y')
    link_alphas = (1., 1., 0.3)
    for name, pos in rect_pos.items():
        instance = instances[name]
        child_groups = (
            getattr(instance, '_single_child_objects', []),
            getattr(instance, '_multi_child_objects', []),
            getattr(instance, '_child_properties', []),
        )
        for kind, child_names in enumerate(child_groups):
            for child_name in child_names:
                x1, y1 = calc_coordinates(rect_pos[child_name],
                                          heights[child_name])
                x2, y2 = calc_coordinates(pos, heights[name])
                if kind == 1:
                    # multi children: the sign of the arc radius depends on
                    # whether this box's y is at or above the child's
                    if y2 >= y1:
                        curve = "arc3,rad=0.7"
                    else:
                        curve = "arc3,rad=-0.7"
                else:
                    # single children and child properties: the end point
                    # is shifted by the box width
                    x2 += rect_width
                    curve = "arc3,rad=-0.2"
                annotate(ax=axes, coord1=(x1, y1), coord2=(x2, y2),
                         connectionstyle=curve,
                         color=link_colors[kind], alpha=link_alphas[kind])

    # draw boxes
    for name, pos in rect_pos.items():
        left = pos[0]
        box_height = heights[name]
        instance = instances[name]
        top = pos[1] + box_height
        title_height = line_heigth * 1.5
        containers = (getattr(instance, '_child_containers', []) +
                      getattr(instance, '_multi_parent_containers', []))

        axes.add_patch(Rectangle(pos, rect_width, box_height,
                                 facecolor='w', edgecolor='k', linewidth=2.))

        # title green
        axes.add_patch(Rectangle((left, top - title_height),
                                 rect_width, title_height,
                                 facecolor='g', edgecolor='k',
                                 alpha=.5, linewidth=2.))

        # single relationship
        singles = getattr(instance, '_single_child_objects', [])
        single_bottom = top - line_heigth * (1.5 + len(singles))
        single_height = len(singles) * line_heigth
        axes.add_patch(Rectangle((left, single_bottom),
                                 rect_width, single_height,
                                 facecolor='c', edgecolor='k', alpha=.5))

        # multi relationship (offset by the height of the single block)
        multis = (getattr(instance, '_multi_child_objects', []) +
                  getattr(instance, '_multi_parent_containers', []))
        multi_bottom = (top - line_heigth * (1.5 + len(multis)) -
                        single_height)
        multi_height = len(multis) * line_heigth
        axes.add_patch(Rectangle((left, multi_bottom),
                                 rect_width, multi_height,
                                 facecolor='m', edgecolor='k', alpha=.5))

        # necessary attr
        n_necessary = len(instance._necessary_attrs)
        necessary_bottom = top - line_heigth * (1.5 + len(containers) +
                                                n_necessary)
        axes.add_patch(Rectangle((left, necessary_bottom),
                                 rect_width, line_heigth * n_necessary,
                                 facecolor='r', edgecolor='k', alpha=.5))

        # name
        has_quantity_attr = hasattr(instance, '_quantity_attr')
        suffix = '* ' if has_quantity_attr else ''
        axes.text(left + rect_width / 2., top - title_height / 2.,
                  name + suffix,
                  horizontalalignment='center',
                  verticalalignment='center',
                  fontsize=fontsize + 2,
                  fontproperties=FontProperties(weight='bold'))

        # relationship
        for row, container in enumerate(containers):
            axes.text(left + left_text_shift,
                      top - line_heigth * (row + 2),
                      container + ': list',
                      horizontalalignment='left',
                      verticalalignment='center',
                      fontsize=fontsize)

        # attributes
        first_attr_row = len(containers) + 2
        for row, attr in enumerate(instance._all_attrs):
            attr_name, attr_type = attr[0], attr[1]

            if has_quantity_attr and instance._quantity_attr == attr_name:
                shown_name = attr_name + '(object itself)'
            else:
                shown_name = attr_name

            if attr_type == pq.Quantity:
                if attr[2] == 0:
                    shown_type = 'Quantity scalar'
                else:
                    shown_type = 'Quantity %dD' % attr[2]
            elif attr_type == np.ndarray:
                shown_type = "np.ndarray %dD dt='%s'" % (attr[2],
                                                         attr[3].kind)
            elif attr_type == datetime:
                shown_type = 'datetime'
            else:
                shown_type = attr_type.__name__

            axes.text(left + left_text_shift,
                      top - line_heigth * (row + first_attr_row),
                      shown_name + ' : ' + shown_type,
                      horizontalalignment='left',
                      verticalalignment='center',
                      fontsize=fontsize)

    xlim, ylim = figsize
    axes.set_xlim(0, xlim)
    axes.set_ylim(0, ylim)
    axes.set_xticks([])
    axes.set_yticks([])
    figure.savefig(filename, dpi=dpi)
def generate_diagram(filename, rect_pos, rect_width, figsize):
    """
    Draw one box per class named in ``rect_pos`` and save the figure.

    Each box lists the relationship containers and attributes of a fake
    instance of the class; arrows connect a class to its child classes.

    filename: destination handed to ``savefig``.
    rect_pos: mapping of class name -> (x, y) position given to its box.
    rect_width: width shared by every box.
    figsize: (width, height), used as the figure size and as the axis limits.
    """
    figure = pyplot.figure(figsize=figsize)
    axes = figure.add_axes([0, 0, 1, 1])

    # One fake instance per class, plus the height of the box showing it.
    heights = {}
    instances = {}
    for name in rect_pos:
        instances[name] = fake_neo(name)
        heights[name] = get_rect_height(name, instances[name])

    # draw connections
    link_colors = ('c', 'm', 'y')
    link_alphas = (1., 1., 0.3)
    for name, pos in rect_pos.items():
        instance = instances[name]
        child_groups = (
            getattr(instance, '_single_child_objects', []),
            getattr(instance, '_multi_child_objects', []),
            getattr(instance, '_child_properties', []),
        )
        for kind, child_names in enumerate(child_groups):
            for child_name in child_names:
                x1, y1 = calc_coordinates(rect_pos[child_name],
                                          heights[child_name])
                x2, y2 = calc_coordinates(pos, heights[name])
                if kind == 1:
                    # multi children: the sign of the arc radius depends on
                    # whether this box's y is at or above the child's
                    if y2 >= y1:
                        curve = "arc3,rad=0.7"
                    else:
                        curve = "arc3,rad=-0.7"
                else:
                    # single children and child properties: the end point
                    # is shifted by the box width
                    x2 += rect_width
                    curve = "arc3,rad=-0.2"
                annotate(ax=axes, coord1=(x1, y1), coord2=(x2, y2),
                         connectionstyle=curve,
                         color=link_colors[kind], alpha=link_alphas[kind])

    # draw boxes
    for name, pos in rect_pos.items():
        left = pos[0]
        box_height = heights[name]
        instance = instances[name]
        top = pos[1] + box_height
        title_height = line_heigth * 1.5
        containers = (getattr(instance, '_child_containers', []) +
                      getattr(instance, '_multi_parent_containers', []))

        axes.add_patch(Rectangle(pos, rect_width, box_height,
                                 facecolor='w', edgecolor='k', linewidth=2.))

        # title green
        axes.add_patch(Rectangle((left, top - title_height),
                                 rect_width, title_height,
                                 facecolor='g', edgecolor='k',
                                 alpha=.5, linewidth=2.))

        # single relationship
        singles = getattr(instance, '_single_child_objects', [])
        single_bottom = top - line_heigth * (1.5 + len(singles))
        single_height = len(singles) * line_heigth
        axes.add_patch(Rectangle((left, single_bottom),
                                 rect_width, single_height,
                                 facecolor='c', edgecolor='k', alpha=.5))

        # multi relationship (offset by the height of the single block)
        multis = (getattr(instance, '_multi_child_objects', []) +
                  getattr(instance, '_multi_parent_containers', []))
        multi_bottom = (top - line_heigth * (1.5 + len(multis)) -
                        single_height)
        multi_height = len(multis) * line_heigth
        axes.add_patch(Rectangle((left, multi_bottom),
                                 rect_width, multi_height,
                                 facecolor='m', edgecolor='k', alpha=.5))

        # necessary attr
        n_necessary = len(instance._necessary_attrs)
        necessary_bottom = top - line_heigth * (1.5 + len(containers) +
                                                n_necessary)
        axes.add_patch(Rectangle((left, necessary_bottom),
                                 rect_width, line_heigth * n_necessary,
                                 facecolor='r', edgecolor='k', alpha=.5))

        # name
        has_quantity_attr = hasattr(instance, '_quantity_attr')
        suffix = '* ' if has_quantity_attr else ''
        axes.text(left + rect_width / 2., top - title_height / 2.,
                  name + suffix,
                  horizontalalignment='center',
                  verticalalignment='center',
                  fontsize=fontsize + 2,
                  fontproperties=FontProperties(weight='bold'))

        # relationship
        for row, container in enumerate(containers):
            axes.text(left + left_text_shift,
                      top - line_heigth * (row + 2),
                      container + ': list',
                      horizontalalignment='left',
                      verticalalignment='center',
                      fontsize=fontsize)

        # attributes
        first_attr_row = len(containers) + 2
        for row, attr in enumerate(instance._all_attrs):
            attr_name, attr_type = attr[0], attr[1]

            if has_quantity_attr and instance._quantity_attr == attr_name:
                shown_name = attr_name + '(object itself)'
            else:
                shown_name = attr_name

            if attr_type == pq.Quantity:
                if attr[2] == 0:
                    shown_type = 'Quantity scalar'
                else:
                    shown_type = 'Quantity %dD' % attr[2]
            elif attr_type == np.ndarray:
                shown_type = "np.ndarray %dD dt='%s'" % (attr[2],
                                                         attr[3].kind)
            elif attr_type == datetime:
                shown_type = 'datetime'
            else:
                shown_type = attr_type.__name__

            axes.text(left + left_text_shift,
                      top - line_heigth * (row + first_attr_row),
                      shown_name + ' : ' + shown_type,
                      horizontalalignment='left',
                      verticalalignment='center',
                      fontsize=fontsize)

    xlim, ylim = figsize
    axes.set_xlim(0, xlim)
    axes.set_ylim(0, ylim)
    axes.set_xticks([])
    axes.set_yticks([])
    figure.savefig(filename, dpi=dpi)
def test__block__cascade(self):
    # A cascaded fake Block must carry one annotated object in every
    # container of the hierarchy below it.
    block = fake_neo(obj_type='Block', cascade=True)

    self.assertTrue(isinstance(block, Block))
    assert_neo_object_is_compliant(block)
    self.assertEqual(block.annotations, self.annotations)

    # direct children of the block
    self.assertEqual(len(block.segments), 1)
    segment = block.segments[0]
    self.assertEqual(segment.annotations, self.annotations)

    self.assertEqual(len(block.recordingchannelgroups), 1)
    group = block.recordingchannelgroups[0]
    self.assertEqual(group.annotations, self.annotations)

    # segment children: all sizes first, then all annotations
    segment_containers = (
        'analogsignalarrays',
        'analogsignals',
        'irregularlysampledsignals',
        'spiketrains',
        'spikes',
        'events',
        'epochs',
        'eventarrays',
        'epocharrays',
    )
    for container in segment_containers:
        self.assertEqual(len(getattr(segment, container)), 1)
    for container in segment_containers:
        self.assertEqual(getattr(segment, container)[0].annotations,
                         self.annotations)

    # children of the recording channel group
    self.assertEqual(len(group.recordingchannels), 1)
    channel = group.recordingchannels[0]
    self.assertEqual(channel.annotations, self.annotations)

    self.assertEqual(len(group.units), 1)
    neuron = group.units[0]
    self.assertEqual(neuron.annotations, self.annotations)

    self.assertEqual(len(group.analogsignalarrays), 1)
    self.assertEqual(group.analogsignalarrays[0].annotations,
                     self.annotations)

    # grandchildren: for each parent, sizes are checked before annotations
    grandchildren = (
        (channel, ('analogsignals', 'irregularlysampledsignals')),
        (neuron, ('spiketrains', 'spikes')),
    )
    for parent, containers in grandchildren:
        for container in containers:
            self.assertEqual(len(getattr(parent, container)), 1)
        for container in containers:
            self.assertEqual(getattr(parent, container)[0].annotations,
                             self.annotations)