def __init__(self, parent=None, parent_sp=None, visible=True, cmap='gray'):
        """Create the axial, coronal and sagittal cross-section images.

        Parameters
        ----------
        parent : scene node | None
            Parent assigned to the 'Cross-Sections' root node.
        parent_sp : object | None
            Forwarded unchanged to ``CrossSectionsSplit.__init__``.
        visible : bool
            Initial visibility of the root node; also stored as
            ``self._visible_cs``.
        cmap : str
            Colormap name, stored as ``self._cmap_cs`` (not used in this
            method).
        """
        CrossSectionsSplit.__init__(self, parent_sp)
        self._visible_cs = visible
        self._cmap_cs = cmap

        # ---------------------------- TRANSFORMS ----------------------------
        # One translation per section, kept on the instance (presumably so
        # other methods can move each slice independently).
        self._tr_coron = vist.STTransform()
        self._tr_sagit = vist.STTransform()
        self._tr_axial = vist.STTransform()

        def _rotation(angle, axis):
            """Wrap ``rotate(angle, axis)`` in a MatrixTransform."""
            return vist.MatrixTransform(rotate(angle, axis))

        # The 90 deg z-rotation is a single object shared by all three chains;
        # every other rotation belongs to exactly one chain.
        rot_90z = _rotation(90, [0, 0, 1])
        sagit_chain = [self._tr_sagit, rot_90z,
                       _rotation(-90, [0, 1, 0]), _rotation(-90, [1, 0, 0])]
        coron_chain = [self._tr_coron, rot_90z,
                       _rotation(180, [1, 0, 0]), _rotation(90, [0, 1, 0])]
        axial_chain = [self._tr_axial, _rotation(-180, [0, 1, 0]), rot_90z]

        # ----------------------------- ELEMENTS -----------------------------
        # Root node grouping the three sections.
        self._node_cs = scene.Node(name='Cross-Sections')
        self._node_cs.parent = parent
        self._node_cs.visible = visible

        def _make_section(name, chain):
            """Create an ImageSection under the root node using `chain`."""
            section = ImageSection(parent=self._node_cs, name=name)
            section.transform = vist.ChainTransform(chain)
            return section

        # Children are created in the order axial, coronal, sagittal.
        self.axial = _make_section('Axial', axial_chain)
        self.coron = _make_section('Coronal', coron_chain)
        self.sagit = _make_section('Sagittal', sagit_chain)

        # Identical GL state on every section: the 'translucent' preset plus
        # explicit overrides (depth test on, face culling and blending off).
        gl_state = dict(depth_test=True, cull_face=False, blend=False,
                        blend_func=('src_alpha', 'one_minus_src_alpha'))
        for section in (self.sagit, self.coron, self.axial):
            section.set_gl_state('translucent', **gl_state)
Beispiel #2
0
 def _define_transformation(self):
     """Set the transform of the sagittal, coronal and axial images.

     Each image receives a ``ChainTransform`` whose first element is an
     ``STTransform`` built from entries of ``self._sh``. The remaining
     elements are fixed rotations: 90 deg about z and/or 180 deg about x.
     The two rotation objects are shared between the images.
     """
     shape = self._sh
     # Shared rotations: 90 deg about z, and 180 deg about x.
     rot_z = vist.MatrixTransform()
     rot_z.rotate(90, (0, 0, 1))
     flip_x = vist.MatrixTransform()
     flip_x.rotate(180, (1, 0, 0))

     def _chain(scale, offset, *rotations):
         """STTransform(scale, offset) followed by the given rotations."""
         normalize = vist.STTransform(scale=scale, translate=offset)
         return vist.ChainTransform([normalize] + list(rotations))

     # Sagittal :
     self._im_sagit.transform = _chain(
         (1. / shape[1], 1. / shape[2], 1.), (-1., 0., 0.), rot_z, flip_x)
     # Coronal :
     self._im_coron.transform = _chain(
         (1. / shape[0], 1. / shape[2], 1.), (0., 0., 0.), rot_z, flip_x)
     # Axial :
     # NOTE(review): the axial scale uses 2/sh while the other two use 1/sh,
     # and it skips the z-rotation -- confirm this asymmetry is intended.
     self._im_axial.transform = _chain(
         (2. / shape[1], 2. / shape[0], 1.), (-1., 0., 0.), flip_x)
Beispiel #3
0
def test_transform_chain():
    """Check ChainTransform construction, composition, simplification,
    mapping and shader-function dependencies."""
    # Make dummy classes for easier distinguishing the transforms

    class DummyTrans(tr.BaseTransform):
        glsl_map = "vec4 trans(vec4 pos) {return pos;}"
        glsl_imap = "vec4 trans(vec4 pos) {return pos;}"

    class TransA(DummyTrans):
        pass

    class TransB(DummyTrans):
        pass

    class TransC(DummyTrans):
        pass

    # Create test transforms
    a, b, c = TransA(), TransB(), TransC()

    # Test Chain creation
    assert tr.ChainTransform().transforms == []
    assert tr.ChainTransform(a).transforms == [a]
    assert tr.ChainTransform(a, b).transforms == [a, b]
    assert tr.ChainTransform(a, b, c, a).transforms == [a, b, c, a]

    # Test composition by multiplication
    assert_chain_objects(a * b, tr.ChainTransform(a, b))
    assert_chain_objects(a * b * c, tr.ChainTransform(a, b, c))
    assert_chain_objects(a * b * c * a, tr.ChainTransform(a, b, c, a))

    # Test adding/prepending to transform
    chain = tr.ChainTransform()
    chain.append(a)
    assert chain.transforms == [a]
    chain.append(b)
    assert chain.transforms == [a, b]
    chain.append(c)
    assert chain.transforms == [a, b, c]
    chain.prepend(b)
    assert chain.transforms == [b, a, b, c]
    chain.prepend(c)
    assert chain.transforms == [c, b, a, b, c]

    # Test simplifying
    t1 = tr.STTransform(scale=(2, 3))
    t2 = tr.STTransform(translate=(3, 4))
    t3 = tr.STTransform(translate=(3, 4))
    # Create multiplied versions
    t123 = t1 * t2 * t3
    t321 = t3 * t2 * t1
    c123 = tr.ChainTransform(t1, t2, t3)
    c321 = tr.ChainTransform(t3, t2, t1)
    c123s = c123.simplified
    c321s = c321.simplified
    #
    assert isinstance(t123, tr.STTransform)  # or the test is useless
    assert isinstance(t321, tr.STTransform)  # or the test is useless
    assert isinstance(c123s, tr.ChainTransform)  # or the test is useless
    assert isinstance(c321s, tr.ChainTransform)  # or the test is useless
    # Simplification must not change what the chain computes: the simplified
    # chains must agree with both the unsimplified chains and the
    # STTransform products. Integer-valued inputs keep the comparison exact.
    point = (1, 1)
    assert c123s.map(point).tolist() == c123.map(point).tolist()
    assert c321s.map(point).tolist() == c321.map(point).tolist()
    assert c123s.map(point).tolist() == t123.map(point).tolist()
    assert c321s.map(point).tolist() == t321.map(point).tolist()

    # Test Mapping
    t1 = tr.STTransform(scale=(2, 3))
    t2 = tr.STTransform(translate=(3, 4))
    chain1 = tr.ChainTransform(t1, t2)
    chain2 = tr.ChainTransform(t2, t1)
    #
    assert chain1.transforms == [t1, t2]  # or the test is useless
    assert chain2.transforms == [t2, t1]  # or the test is useless
    #
    m12 = (t1 * t2).map((1, 1)).tolist()
    m21 = (t2 * t1).map((1, 1)).tolist()
    m12_ = chain1.map((1, 1)).tolist()
    m21_ = chain2.map((1, 1)).tolist()
    #
    assert m12 != m21
    assert m12 == m12_
    assert m21 == m21_

    # Test shader map
    t1 = tr.STTransform(scale=(2, 3))
    t2 = tr.STTransform(translate=(3, 4))
    chain = tr.ChainTransform(t1, t2)
    #
    funcs = chain.shader_map().dependencies()
    funcsi = chain.shader_imap().dependencies()
    #
    assert t1.shader_map() in funcs
    assert t2.shader_map() in funcs
    assert t1.shader_imap() in funcsi
    assert t2.shader_imap() in funcsi
Beispiel #4
0
    def __init__(self,
                 xyz=None,
                 channels=None,
                 system='cartesian',
                 unit='degree',
                 title=None,
                 title_color='black',
                 title_size=20.,
                 line_color='black',
                 line_width=4.,
                 chan_size=12.,
                 chan_offset=(0., 0., 0.),
                 chan_mark_color='white',
                 chan_mark_symbol='disc',
                 chan_txt_color='black',
                 bgcolor='white',
                 cbar=True,
                 cb_txt_size=10.,
                 margin=.05,
                 parent=None):
        """Build the topoplot scene graph.

        Creates a root ``'Topoplot'`` node scaled by 800 on every axis. Under
        it go the head outline (circle, nose, ears), the channel markers and
        labels, and an optional colorbar. The method also stores a grid
        interpolation closure as ``self._grid_interpolation``.

        Parameters
        ----------
        xyz, channels, system, unit :
            Forwarded to ``self._get_channel_coordinates``. Their exact format
            is defined there (not visible here).
        title : str | None
            Title text.
        title_color, line_color, chan_txt_color, chan_mark_color, bgcolor :
            Colors, converted with ``color2vb``. ``bgcolor`` and
            ``chan_mark_color`` are stored on the instance.
        title_size : float
            Title font size (multiplied by 1.1 after creation).
        line_width : float
            Width of the head / nose / ear lines.
        chan_size : float
            Font size of the channel labels.
        chan_offset : tuple
            Extra (x, y, z) offset added to the channel-label translation.
        chan_mark_symbol : str
            Marker symbol, stored as ``self._chan_mark_symbol``.
        cbar : bool
            Build a colorbar; also widens the camera rectangle.
        cb_txt_size : float
            Colorbar text size (the colorbar title uses 1.2 times it).
        margin : float
            Relative margin added around the camera rectangle.
        parent : scene node | None
            Parent of the root node.
        """
        # ======================== VARIABLES ========================
        self._bgcolor = color2vb(bgcolor)
        scale = 800.  # fix GL bugs for small plots
        # Single-point placeholder data used to create the visuals.
        pos = np.zeros((1, 3), dtype=np.float32)
        # Colors :
        title_color = color2vb(title_color)
        line_color = color2vb(line_color)
        chan_txt_color = color2vb(chan_txt_color)
        self._chan_mark_color = color2vb(chan_mark_color)
        self._chan_mark_symbol = chan_mark_symbol
        # Disc interpolation : the pix-sized grid is resampled with a step of
        # `_interp`. csize is the resulting side length (presumably in
        # pixels of the disc image).
        self._interp = .1
        self._pix = 64
        csize = int(self._pix / self._interp) if self._interp else self._pix
        radius = csize / 2  # head circle radius; the circle is centred at (radius, radius)

        # ======================== NODES ========================
        # Main topoplot node :
        self.node = scene.Node(name='Topoplot', parent=parent)
        self.node.transform = vist.STTransform(scale=[scale] * 3)
        # Headset + channels :
        self.node_headfull = scene.Node(name='HeadChan', parent=self.node)
        # Headset node :
        self.node_head = scene.Node(name='Headset', parent=self.node_headfull)
        # Channel node (pushed to z=-10 relative to the head) :
        self.node_chan = scene.Node(name='Channels', parent=self.node_headfull)
        self.node_chan.transform = vist.STTransform(translate=(0., 0., -10.))
        # Cbar node (previously mis-named 'Channels', clashing with the
        # channel node above) :
        self.node_cbar = scene.Node(name='Cbar', parent=self.node)
        # Shared keyword arguments of the head / nose / ear lines :
        kw_line = {
            'width': line_width,
            'color': line_color,
            'parent': self.node_head
        }

        # ======================== PARENT VISUALS ========================
        # Main disc :
        self.disc = visuals.Image(pos=pos,
                                  name='Disc',
                                  parent=self.node_head,
                                  interpolation='bilinear')
        # Title :
        self.title = visuals.Text(text=title,
                                  pos=(0., .6, 0.),
                                  name='Title',
                                  parent=self.node,
                                  font_size=title_size,
                                  color=title_color,
                                  bold=True)
        self.title.font_size *= 1.1

        # ======================== HEAD / NOSE / EAR ========================
        # ------------------ HEAD ------------------
        # Head visual :
        self.head = visuals.Line(pos=pos, name='Head', **kw_line)
        # Head circle (z = -1) :
        theta = np.arange(0, 2 * np.pi, 0.001)
        head = np.full((len(theta), 3), -1., dtype=np.float32)
        head[:, 0] = radius * (1. + np.cos(theta))
        head[:, 1] = radius * (1. + np.sin(theta))
        self.head.set_data(pos=head)

        # ------------------ NOSE ------------------
        # Nose visual :
        self.nose = visuals.Line(pos=pos, name='Nose', **kw_line)
        # Nose data : two segments meeting at the tip above the circle.
        wn, hn = csize * 50. / 512., csize * 30. / 512.
        nose = np.array([[radius - wn, 2 * radius - wn, 2.],
                         [radius, 2 * radius + hn, 2.],
                         [radius, 2 * radius + hn, 2.],
                         [radius + wn, 2 * radius - wn, 2.]])
        self.nose.set_data(pos=nose, connect='segments')

        # ------------------ EAR ------------------
        # Ears are ellipses (half-width we, half-height he) centred at
        # mid-height on either side of the head circle.
        we, he = csize * 10. / 512., csize * 30. / 512.
        ye = radius + he * np.sin(theta)
        # Ear left visual :
        self.earL = visuals.Line(pos=pos, name='EarLeft', **kw_line)
        # Ear left data :
        ear_l = np.full((len(theta), 3), 3., dtype=np.float32)
        ear_l[:, 0] = 2 * radius + we * np.cos(theta)
        ear_l[:, 1] = ye
        self.earL.set_data(pos=ear_l)

        # Ear right visual :
        self.earR = visuals.Line(pos=pos, name='EarRight', **kw_line)
        # Ear right data :
        ear_r = np.full((len(theta), 3), 3., dtype=np.float32)
        ear_r[:, 0] = 0. + we * np.cos(theta)
        ear_r[:, 1] = ye
        self.earR.set_data(pos=ear_r)

        # ================== CHANNELS ==================
        # Channel's markers :
        self.chanMarkers = visuals.Markers(pos=pos,
                                           name='ChanMarkers',
                                           parent=self.node_chan)
        # Channel's text :
        self.chanText = visuals.Text(pos=pos,
                                     name='ChanText',
                                     parent=self.node_chan,
                                     anchor_x='center',
                                     color=chan_txt_color,
                                     font_size=chan_size)

        # ================== CAMERA ==================
        # (x, y, width, height); the colorbar (if any) needs 30% more width.
        self.rect = ((-scale / 2) * (1 + margin), (-scale / 2) * (1 + margin),
                     scale * (1. + cbar * .3 + margin),
                     scale * (1.11 + margin))

        # ================== CBAR ==================
        if cbar:
            self.cbar = CbarVisual(cbtxtsz=1.2 * cb_txt_size,
                                   txtsz=cb_txt_size,
                                   txtcolor=title_color,
                                   cbtxtsh=2.,
                                   parent=self.node_cbar)
            self.node_cbar.transform = vist.STTransform(scale=(.6, .4, 1.),
                                                        translate=(.6, 0., 0.))

        # ================== COORDINATES ==================
        # `auto` selects between the two head-placement strategies below; its
        # meaning is defined by _get_channel_coordinates (which also sets
        # self._xyz, read in the first branch).
        auto = self._get_channel_coordinates(xyz, channels, system, unit)
        if auto:
            eucl = np.sqrt(self._xyz[:, 0]**2 + self._xyz[:, 1]**2).max()
            self.node_head.transform = vpnormalize(head, dist=2 * eucl)
            # Rescale between (-1:1, -1:1) = circle :
            circle = vist.STTransform(scale=(.5 / eucl, .5 / eucl, 1.))
            self.node_headfull.transform = circle
            # Text translation :
            label_shift = np.array([0., .8, 0.]) + np.array(chan_offset)
        else:
            # Get coordinates of references along the x and y-axis :
            ref_x, ref_y = self._get_ref_coordinates()
            # Recenter the topoplot :
            t = vist.ChainTransform()
            t.prepend(vprecenter(head))
            # Rescale (-ref_x:ref_x, -ref_y:ref_y) (ref_x != ref_y => ellipse)
            coef_x = 2 * ref_x / head[:, 0].max()
            coef_y = 2 * ref_y / head[:, 1].max()
            t.prepend(vist.STTransform(scale=(coef_x, coef_y, 1.)))
            self.node_head.transform = t
            # Rescale between (-1:1, -1:1) = circle :
            circle = vist.STTransform(scale=(.5 / ref_x, .5 / ref_y, 1.))
            self.node_headfull.transform = circle
            # Text translation :
            label_shift = np.array([0., .04, 0.]) + np.array(chan_offset)
        self.chanText.transform = vist.STTransform(translate=label_shift)

        # ================== GRID INTERPOLATION ==================
        # `scipy.interpolate.interp2d` was removed in SciPy 1.14. A bilinear
        # (kx=ky=1, s=0) RectBivariateSpline computes the same piecewise-
        # linear interpolant. It takes the row coordinate first, hence the
        # (y, x) / (ynew, xnew) order, and returns an array of shape
        # (len(ynew), len(xnew)) just like interp2d(x, y, grid)(xnew, ynew).
        from scipy.interpolate import RectBivariateSpline

        # Interpolation vectors :
        x = y = np.arange(0, self._pix, 1)
        xnew = ynew = np.arange(0, self._pix, self._interp)

        # Grid interpolation function (grid is expected as (len(y), len(x))) :
        def _grid_interpolation(grid):
            spline = RectBivariateSpline(y, x, grid, kx=1, ky=1)
            return spline(ynew, xnew)

        self._grid_interpolation = _grid_interpolation