class LayerWrapper(HasTraits):
    """
    Pairs a rendering engine with the display settings used to drive it.

    The engine is looked up in ``ENGINES`` using the ``method`` trait and is fed
    from a named data source of the supplied pipeline. Changes to the colour
    map, colour limits, alpha or data source name push fresh data to the engine
    and fire the ``on_update`` signal.
    """
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self,
                 pipeline,
                 method='points',
                 ds_name='',
                 cmap='gist_rainbow',
                 clim=None,
                 alpha=1.0,
                 visible=True,
                 method_args=None,
                 context=None):
        """
        Parameters
        ----------
        pipeline :
            Object exposing a ``layer_datasources`` mapping and an ``onRebuild``
            signal.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        ds_name : str
            Name of the data source to display. An empty string selects the
            pipeline itself.
        cmap : str
            Name of the colour map (an attribute of ``cm``).
        clim : list of float, optional
            Colour limits. Defaults to ``[0, 1]``.
        alpha : float
        visible : bool
        method_args : dict, optional
            Trait values applied to the engine once it has been created.
        context : optional
            Passed unchanged to the engine constructor.
        """
        HasTraits.__init__(self)

        self._pipeline = pipeline
        # _set_method hands this to the engine constructor; it has to exist
        # before the first engine is built.
        self._context = context
        self.engine = None

        self.cmap = cmap
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha

        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        self._eng_params = {} if method_args is None else dict(method_args)

        if self.method == method:
            # assigning an unchanged value fires no trait notification, so
            # build the engine explicitly
            self._set_method()
        else:
            self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        return self.engine.bbox

    @property
    def colour_map(self):
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """
        Names of all available data sources, with ``name.channel`` entries
        added for each colour channel of a ``tabular.ColourFilter``.
        """
        names = []
        for k, v in self._namespace.items():
            names.append(k)
            if isinstance(v, tabular.ColourFilter):
                names.extend('.'.join([k, c]) for c in v.getColourChans())

        return names

    @property
    def datasource(self):
        """
        The data source selected by ``dsname``, or None if it cannot be found.
        """
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            parent = self._namespace.get(ds, None)
            if parent is None:
                return None
            return parent.get_channel_ds(channel)

        return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """
        Replace the engine with one matching the current ``method``, carrying
        over the previous engine's ``point_size`` and ``vertexColour``.
        """
        if self.engine:
            self._eng_params = self.engine.trait_get('point_size',
                                                     'vertexColour')

        self.engine = ENGINES[self.method](self._context)
        self.engine.trait_set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to span the current colour data."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """
        Push the current data source and display settings to the engine and
        fire ``on_update``. Does nothing beyond printing if either the engine
        or the data source is missing.
        """
        print('lw update')
        engine = self.engine
        if engine is None:
            return

        ds = self.datasource
        if ds is None:
            return

        engine.update_from_datasource(ds, getattr(cm, self.cmap), self.clim,
                                      self.alpha)
        self.on_update.send(self)

    def render(self, gl_canvas):
        """Render through the engine if we are visible and have an engine."""
        if self.visible and self.engine is not None:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """
        Return the data used for colouring, or a ``[0, 1]`` placeholder when it
        cannot be looked up.
        """
        try:
            cdata = self.datasource[self.engine.vertexColour]
        except (KeyError, TypeError):
            # TypeError: the datasource is None and therefore not indexable
            cdata = np.array([0, 1])

        return cdata

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View(
            [
                Group([
                    Item('dsname',
                         label='Data',
                         editor=EnumEditor(values=self.data_source_names)),
                ]),
                Item('method'),
                Group([
                    Item('engine',
                         style='custom',
                         show_label=False,
                         editor=InstanceEditor(
                             view=self.engine.view(self.datasource.keys()))),
                ]),
                Group([
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata),
                         show_label=False),
                ]),
                Group([
                    Item('cmap', label='LUT'),
                    Item('alpha'),
                    Item('visible')
                ],
                      orientation='horizontal',
                      layout='flow')
            ], )

    def default_traits_view(self):
        return self.default_view
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, rendered by one of the engines
    registered in ``ENGINES``.

    The layer pulls coordinates (and optionally normals and a colouring
    variable) from a named pipeline data source and converts them into the
    vertex, normal and colour arrays exposed through the ``get_*`` accessors.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='points',
                 dsname='',
                 context=None,
                 **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Source of layer data (may be None). Must provide
            ``get_layer_data``, ``layer_data_source_names`` and ``onRebuild``
            when given.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the data source within the pipeline.
        context : optional
            Passed to ``EngineLayer.__init__`` and to the engine constructor.
        **kwargs
            Additional trait values to set on this layer.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.trait_set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        if self.method == method:
            # assigning an unchanged value fires no trait notification, so
            # build the engine explicitly
            self._set_method()
        else:
            # the 'method' listener builds the engine; calling _set_method as
            # well would construct it (and refresh the data) twice
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property),
        or None if we have no pipeline.
        """
        if self._pipeline is None:
            return None
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Create the engine for the current ``method`` and refresh our data."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the colouring variable, or a ``[0, 1]``
        placeholder when it cannot be looked up.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            # TypeError: the datasource is None and therefore not indexable
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to span the current colouring variable."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """
        Refresh the GUI choice lists and, if both an engine and a datasource
        are available, rebuild our render data and fire ``on_update``.
        """
        print('lw update')
        if self._pipeline is not None:
            self._datasource_choices = self._pipeline.layer_data_source_names

        ds = self.datasource
        if ds is None:
            return

        self._datasource_keys = sorted(ds.keys())

        if self.engine is not None:
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Extract coordinates, colouring variable and (when present) normals from
        ``ds`` and pass them to ``update_data``.

        Parameters
        ----------
        ds :
            Object indexable by column name and providing ``keys()``.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        # fall back to a flat z when no z key is configured or the column is missing
        try:
            z = ds[self.z_key] if self.z_key is not None else 0 * x
        except KeyError:
            z = 0 * x

        if self.vertexColour != '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        normals = {}
        if self.xn_key in ds.keys():
            normals = dict(xn=ds[self.xn_key],
                           yn=ds[self.yn_key],
                           zn=ds[self.zn_key])

        self.update_data(x,
                         y,
                         z,
                         c,
                         cmap=getattr(cm, self.cmap),
                         clim=self.clim,
                         alpha=self.alpha,
                         **normals)

    def update_data(self,
                    x=None,
                    y=None,
                    z=None,
                    colors=None,
                    cmap=None,
                    clim=None,
                    alpha=1.0,
                    xn=None,
                    yn=None,
                    zn=None):
        """
        Build the vertex, normal and colour arrays from raw coordinate data.

        All previously stored render data is cleared first. If any of ``x``,
        ``y`` or ``z`` is None no geometry is stored and the bounding box is
        reset to None.

        Parameters
        ----------
        x, y, z : ndarray, optional
            Point coordinates.
        colors : ndarray, optional
            Values to be colour mapped.
        cmap : callable, optional
            Colour map, called with the values scaled by ``clim``.
        clim : sequence of two floats, optional
            Lower and upper limits used to scale ``colors``.
        alpha : float
            Written into the fourth colour column.
        xn, yn, zn : ndarray, optional
            Normal components. A constant placeholder is used if ``xn`` is None.
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            n_points = len(x.ravel())
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(n_points, 3)

            if xn is not None:
                normals = np.vstack(
                    (xn.ravel(), yn.ravel(),
                     zn.ravel())).T.ravel().reshape(n_points, 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array(
                [x.min(), y.min(),
                 z.min(), x.max(),
                 y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # colour mapping needs the values, the limits AND the colour map
        if clim is not None and colors is not None and cmap is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        elif vertices is not None:
            # nothing to colour map - fall back to all-ones RGBA
            cs = np.ones((vertices.shape[0], 4), 'f')
        else:
            cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self,
                   vertices=None,
                   normals=None,
                   colors=None,
                   color_map=None,
                   color_limit=None,
                   alpha=None):
        """Store each argument that is not None; None leaves the stored value untouched."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     visible_when=
                     "method in ['pointsprites', 'transparent_points']",
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)),
                Item('point_size',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)))
        ])

    def default_traits_view(self):
        return self.default_view
# Example #3
# 0
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Source of layer data (may be None). Must provide
            ``get_layer_data``, ``dataSources`` and ``onRebuild`` when given.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the data source within the pipeline.
        context : optional
            Passed to ``EngineLayer.__init__`` and to the engine constructor.
        **kwargs
            Additional trait values to set on this layer.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.trait_set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (looked up in the pipeline
        through our dsname property), or None if we have no pipeline.
        """
        if self._pipeline is None:
            return None
        return self._pipeline.get_layer_data(self.dsname)

    @property
    def _ds_class(self):
        """Class used to decide which pipeline data sources are offered as choices."""
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """Create the engine for the current ``method`` and refresh our data."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the colouring variable, or a ``[0, 1]``
        placeholder when it cannot be looked up.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to span the colouring variable, then refresh."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """
        Refresh the GUI choice lists and, if both an engine and a datasource
        are available, rebuild our render data and fire ``on_update``.
        """
        if self._pipeline is not None:
            # resolve the class once rather than once per data source
            ds_class = self._ds_class
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, ds_class)]

        ds = self.datasource
        if ds is None:
            return

        dks = ['constant', ]
        if hasattr(ds, 'keys'):
            dks = dks + sorted(ds.keys())
        self._datasource_keys = dks

        if self.engine is not None:
            print('lw update')
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build per-triangle-corner vertex, normal and colour arrays from a mesh.

        Vertices are expanded through the face index array, so every triangle
        contributes three rows. Normals are taken per vertex or repeated per
        face depending on ``normal_mode``.

        Parameters
        ----------
        ds :
            Mesh object providing ``vertices``, ``faces``, ``vertex_normals``
            and ``face_normals``, and indexable by variable name when
            ``vertexColour`` names a variable.

        Returns
        -------
        None
        """
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # x, y, z (and the normal components) are 1-D rows of equal length
        n_verts = len(x)
        self._vertices = np.vstack((x, y, z)).T.ravel().reshape(n_verts, 3)
        self._normals = np.vstack((xn, yn, zn)).T.ravel().reshape(n_verts, 3)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])

        cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
        cs = cmap(cs_)

        if self.method in ['flat', 'tessel']:
            # these methods scale opacity with the (normalised) colour value
            alpha = cs_ * alpha

        cs[:, 3] = alpha

        if self.method == 'tessel':
            cs = np.power(cs, 0.333)

        self._colors = cs.ravel().reshape(len(c), 4)

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
# Example #4
# 0
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Either a pipeline-like object (providing ``get_layer_data`` and
            ``dataSources``) or a dictionary mapping names to images.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the image within the pipeline.
        display_opts : optional
            Display options object used by ``sync_to_display_opts``.
        context : optional
            Passed to ``EngineLayer.__init__`` and to the engine constructor.
        **kwargs
            Additional trait values to set on this layer.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        self._do = display_opts  # a dh5view display_options instance - used by sync_to_display_opts

        # (dsname, slice, channel) of the currently cached image plane
        self._im_key = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.trait_set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (self._pipeline is not None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the image we are connected to (through our dsname property), or
        None if there is no pipeline or the name is unknown to a dictionary
        pipeline.
        """
        if self._pipeline is None:
            return None

        get_layer_data = getattr(self._pipeline, 'get_layer_data', None)
        if get_layer_data is not None:
            return get_layer_data(self.dsname)

        # fallback if pipeline is a dictionary
        try:
            return self._pipeline[self.dsname]
        except KeyError:
            # update() treats a None datasource as "nothing to show"
            return None

    @property
    def _ds_class(self):
        """Class used to decide which pipeline entries are offered as choices."""
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        """Create the engine for the current ``method`` and refresh our data."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """
        Refresh the list of selectable images and, if both an engine and a
        datasource are available, rebuild our render data and fire
        ``on_update``.
        """
        if self._pipeline is not None:
            # a dictionary pipeline holds the images directly
            sources = getattr(self._pipeline, 'dataSources', self._pipeline)
            ds_class = self._ds_class
            self._datasource_choices = [k for k, v in sources.items() if isinstance(v, ds_class)]

        if self.engine is None:
            return

        ds = self.datasource
        if ds is None:
            return

        print('lw update')
        self.update_from_datasource(ds)
        self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """
        Copy colour limits, colour map and visibility for our channel from a
        display options object.

        Parameters
        ----------
        do : optional
            Display options providing per-channel ``Offs``, ``Gains``,
            ``cmaps`` and ``show``. Defaults to the options given at
            construction; if neither is available nothing happens.
        """
        if do is None:
            do = self._do
        if do is None:
            return

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.trait_set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Cache the selected image plane and record the current display settings.

        The image plane, bounding box and bounds are only re-read when the
        (dsname, slice, channel) combination has changed since the last call.

        Parameters
        ----------
        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        im_key = (self.dsname, self.slice, self.channel)

        if self._im_key != im_key:
            self._im_key = im_key
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])

            self._bounds = [x0, y0, x1, y1]

        self._alpha = float(self.alpha)
        self._color_map = getattr(cm, self.cmap)
        self._color_limit = self.clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        """
        Return a subsampled (every 20th pixel) flat view of the cached image,
        or a ``[0, 1]`` placeholder if no image has been loaded yet.
        """
        im = getattr(self, '_im', None)
        if im is None:
            return np.array([0, 1])
        return im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class TrackRenderLayer(EngineLayer):
    """
    A layer for viewing tracking data.

    Point coordinates are read from a data source obtained from the pipeline
    (looked up by ``dsname``) and arranged track by track. Two kinds of data
    source are handled:

    - a ``ClumpManager``, whose tracks are iterated via ``.all``; and
    - anything else, which is assumed to be a tabular source containing a
      column (named by ``clump_key``) holding each point's track identifier.

    After an update the layer exposes vertices, normals, per-vertex colours,
    a bounding box, and the per-track sizes / start offsets (``clumpSizes``
    and ``clumpStarts``).
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    line_width = Float(1.0, desc='Track line width')
    method = Enum(*ENGINES.keys(), desc='Method used to display tracks')
    clump_key = CStr('clumpIndex',
                     desc="Name of column containing the track identifier")
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    # choices offered by the 'Colour' and 'Data' drop-downs in default_view;
    # both are refreshed in update()
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='tracks',
                 dsname='',
                 context=None,
                 **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Object providing ``get_layer_data``, ``layer_data_source_names``
            and an ``onRebuild`` signal. May be None, in which case the layer
            has no data source and is not connected to any rebuild signal.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the data source to request from the pipeline.
        context :
            Passed to ``EngineLayer.__init__`` (and, via ``self._context``,
            to the engine constructor in ``_set_method``).
        **kwargs :
            Forwarded to ``EngineLayer.__init__`` and then applied to this
            object's traits.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, clump_key')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        #logger.debug('Setting dsname and method')
        self.dsname = dsname
        self.method = method

        # called explicitly so that an engine is created even when `method`
        # did not change value (in which case the trait handler does not fire)
        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).

        Returns None when the layer was constructed without a pipeline.
        """
        if self._pipeline is None:
            return None

        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Instantiate the engine selected by ``method`` and refresh the layer."""
        #logger.debug('Setting layer method to %s' % self.method)
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the ``vertexColour`` variable for all points.

        Falls back to ``np.array([0, 1])`` when there is no data source or
        when the data source has no entry named ``vertexColour`` (KeyError).
        """
        ds = self.datasource
        if ds is None:
            return np.array([0, 1])

        try:
            if isinstance(ds, ClumpManager):
                cdata = []
                for track in ds.all:
                    cdata.extend(track[self.vertexColour])
                return np.array(cdata)

            # Assume tabular dataset
            return ds[self.vertexColour]
        except KeyError:
            return np.array([0, 1])

    def _update(self, *args, **kwargs):
        """
        Reset ``clim`` to the (NaN-ignoring) range of the colour variable.

        Registered as the handler for changes to ``vertexColour``. Assigning
        ``clim`` in turn fires the ``update`` handler registered on it.
        """
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]
        #self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """
        Refresh the GUI choice lists and the cached render data.

        Extra positional / keyword arguments are accepted (and ignored) so
        that this method can be used directly as a trait-change handler and
        as a receiver for the pipeline's ``onRebuild`` signal.
        """
        if self._pipeline is not None:
            self._datasource_choices = self._pipeline.layer_data_source_names

        # look the data source up once rather than on every use
        ds = self.datasource
        if ds is None:
            return

        if isinstance(ds, ClumpManager):
            # Grab the keys from the first Track in the ClumpManager
            self._datasource_keys = sorted(ds[0].keys())
        else:
            # Assume we have a tabular data source
            self._datasource_keys = sorted(ds.keys())

        if self.engine is not None:
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box set by the last update (None before the first one)."""
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Rebuild vertices, normals, colours and track offsets from ``ds``.

        Parameters
        ----------
        ds :
            Either a ``ClumpManager`` or a tabular data source indexable by
            column name.

        Sets ``clumpSizes``, ``clumpStarts``, ``_vertices``, ``_normals``,
        ``_bbox``, ``_colors``, ``_color_map``, ``_color_limit`` and
        ``_alpha``.
        """
        if isinstance(ds, ClumpManager):
            x = []
            y = []
            z = []
            c = []
            self.clumpSizes = []

            # Copy data from tracks. This is already in clump order
            # thanks to ClumpManager
            for track in ds.all:
                x.extend(track['x'])
                y.extend(track['y'])
                z.extend(track['z'])
                self.clumpSizes.append(track.nEvents)

                if self.vertexColour != '':
                    c.extend(track[self.vertexColour])
                else:
                    # no colour variable selected - use a constant
                    c.extend([0] * len(track['x']))

            x = np.array(x)
            y = np.array(y)
            z = np.array(z)
            c = np.array(c)

        else:
            # Assume tabular data source
            x, y = ds[self.x_key], ds[self.y_key]

            if self.z_key is not None:
                try:
                    z = ds[self.z_key]
                except KeyError:
                    z = 0 * x
            else:
                z = 0 * x

            if self.vertexColour != '':
                c = ds[self.vertexColour]
            else:
                c = 0 * x

            # Work out clump start and finish indices
            # TODO - optimize / precompute????
            ci = ds[self.clump_key]

            NClumps = int(ci.max())

            # Identifier k is stored in slot k - 1.
            # NOTE(review): an identifier of 0 therefore lands in slot -1,
            # i.e. is merged into the last clump - confirm this is intended.
            clist = [[] for _ in range(NClumps)]
            for i, cl_i in enumerate(ci):
                clist[int(cl_i - 1)].append(i)

            # This and self.clumpStarts are kept as attributes for
            # compatibility with the old Track rendering layer,
            # PYME.LMVis.gl_render3D.TrackLayer
            self.clumpSizes = [len(cl_i) for cl_i in clist]

            # reorder x, y, z, c in clump order. The cast is needed because
            # empty clumps contribute float-typed empty arrays; the builtin
            # `int` is used as the `np.int` alias no longer exists in
            # current NumPy releases.
            order = np.hstack([np.array(cl) for cl in clist]).astype(int)

            x = x[order]
            y = y[order]
            z = z[order]
            c = c[order]

        self.clumpStarts = np.cumsum([
            0,
        ] + self.clumpSizes)

        #do normal vertex stuff
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

        self._vertices = vertices
        # constant normals; NOTE(review): the meaning of -0.69 is not visible
        # here - presumably a value the rendering engine expects, confirm.
        self._normals = -0.69 * np.ones(vertices.shape)
        self._bbox = np.array(
            [x.min(), y.min(),
             z.min(), x.max(),
             y.max(), z.max()])

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        if clim is not None:
            # scale the colour variable by clim, then map through the LUT
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = float(self.alpha)

            self._colors = cs.ravel().reshape(len(c), 4)
        elif vertices is not None:
            self._colors = np.ones((vertices.shape[0], 4), 'f')

        self._color_map = cmap
        self._color_limit = clim
        self._alpha = float(self.alpha)

    def get_vertices(self):
        """Return the vertex array built by ``update_from_datasource``."""
        return self._vertices

    def get_normals(self):
        """Return the normals array built by ``update_from_datasource``."""
        return self._normals

    def get_colors(self):
        """Return the per-vertex colours built by ``update_from_datasource``."""
        return self._colors

    def get_color_map(self):
        """Return the colour map object used at the last update."""
        return self._color_map

    @property
    def colour_map(self):
        """The colour map object used at the last update (read-only)."""
        return self._color_map

    def get_color_limit(self):
        """Return the colour limits used at the last update."""
        return self._color_limit

    @property
    def default_view(self):
        """Build the TraitsUI view used to edit this layer."""
        from traitsui.api import View, Item, Group, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)), Item('line_width'))
        ])
        #buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        """Return the view built by the ``default_view`` property."""
        return self.default_view