Exemplo n.º 1
0
class PSFSettings(HasTraits):
    """Parameters describing a simulated PSF, editable through an auto-generated traitsui dialog."""
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])
    four_pi = Bool(False)

    def default_traits_view(self):
        """Return the traitsui View; the lower-objective Zernike modes and phases are only shown in 4Pi mode."""
        from traitsui.api import View, Item

        # condition shared by the editors which only make sense for a 4Pi configuration
        only_for_4pi = 'four_pi==True'

        items = [
            Item(name='wavelength_nm'),
            Item(name='NA'),
            Item(name='vectorial'),
            Item(name='four_pi', label='4Pi'),
            Item(name='zernike_modes'),
            Item(name='zernike_modes_lower', visible_when=only_for_4pi),
            Item(name='phases', visible_when=only_for_4pi, label='phases/pi'),
        ]

        return View(*items, resizable=True, buttons=['OK'])
Exemplo n.º 2
0
class PointFeatureBase(ModuleBase):
    """
    Common base class for feature extraction routines - implements normalisation and PCA routines.

    Derived classes compute a 2D feature array (one row per point, one column per feature) and pass it to
    ``_process_features`` together with the source data.
    """

    outputColumnName = CStr('features')
    # if true, also outputs a column ('feat_0', 'feat_1', ...) for each feature - useful for visualising
    columnForEachFeature = Bool(False)

    normalise = Bool(True)  #subtract mean and divide by std. deviation

    # reduce feature dimensionality by performing PCA - TODO - should this be a separate module and be chained instead?
    PCA = Bool(True)
    PCA_components = Int(3)  # 0 = same dimensionality as features

    def _process_features(self, data, features):
        """
        Optionally normalise and PCA-reduce ``features``, and attach them to a mapping filter over ``data``.

        Parameters
        ----------
        data : tabular data source
            The points the features were computed from. Its ``mdh`` (if any) is propagated to the output.
        features : numpy.ndarray
            2D array of shape (n_points, n_features).

        Returns
        -------
        PYME.IO.tabular.MappingFilter
            ``data`` with an extra column ``outputColumnName`` holding the (processed) features, plus one
            ``feat_<i>`` column per feature if ``columnForEachFeature`` is set. When PCA is performed the fitted
            sklearn PCA object is stored as ``out.pca``.
        """
        from PYME.IO import tabular
        out = tabular.MappingFilter(data)
        out.mdh = getattr(data, 'mdh', None)

        if self.normalise:
            features = features - features.mean(0)[None, :]
            std = features.std(0)
            # a constant feature has zero std. Dividing by it would produce NaNs which then break PCA, so leave
            # such (already zero-mean) features at zero instead.
            std[std == 0] = 1
            features = features / std[None, :]

        if self.PCA:
            from sklearn.decomposition import PCA

            pca = PCA(n_components=(
                self.PCA_components if self.PCA_components > 0 else None
            )).fit(features)
            features = pca.transform(features)

            out.pca = pca  #save the pca object just in case we want to look at what the principle components are (this is hacky)

        out.addColumn(self.outputColumnName, features)

        if self.columnForEachFeature:
            for i in range(features.shape[1]):
                out.addColumn('feat_%d' % i, features[:, i])

        return out
Exemplo n.º 3
0
class ArithmaticFilter(ModuleBase):
    """
    Module with two image inputs and one image output.

    Derived classes implement ``applyFilter(data0, data1, chanNum, frNum, image0)``, which receives matching
    float32 slices of the two inputs, and may override ``completeMetadata(im)`` to annotate the result.

    Parameters
    ----------
    inputName0: PYME.IO.image.ImageStack
    inputName1: PYME.IO.image.ImageStack
    outputName: PYME.IO.image.ImageStack

    """
    inputName0 = Input('input')
    inputName1 = Input('input')
    outputName = Output('filtered_image')

    # when True, applyFilter is called once per (z, channel) plane instead of once per channel
    processFramesIndividually = Bool(False)

    def filter(self, image0, image1):
        """Apply ``applyFilter`` to corresponding slices of both images and wrap the result in an ImageStack."""
        per_frame = self.processFramesIndividually

        def as_float(image, z, chan):
            # ``z`` is either a single plane index or slice(None) for the whole stack
            return image.data[:, :, z, chan].squeeze().astype('f')

        filt_ims = []
        for chanNum in range(image0.data.shape[3]):
            if per_frame:
                planes = []
                for i in range(image0.data.shape[2]):
                    result = self.applyFilter(as_float(image0, i, chanNum),
                                              as_float(image1, i, chanNum),
                                              chanNum, i, image0)
                    planes.append(np.atleast_3d(result))
                filt_ims.append(np.concatenate(planes, 2))
            else:
                whole = slice(None)
                result = self.applyFilter(as_float(image0, whole, chanNum),
                                          as_float(image1, whole, chanNum),
                                          chanNum, 0, image0)
                filt_ims.append(np.atleast_3d(result))

        im = ImageStack(filt_ims, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(image0.mdh)
        im.mdh['Parents'] = '%s, %s' % (image0.filename, image1.filename)

        self.completeMetadata(im)

        return im

    def execute(self, namespace):
        """Read both inputs from ``namespace`` and store the filtered image under ``outputName``."""
        image0 = namespace[self.inputName0]
        image1 = namespace[self.inputName1]
        namespace[self.outputName] = self.filter(image0, image1)

    def completeMetadata(self, im):
        """Hook for derived classes to add metadata to the output image; does nothing by default."""
        pass
Exemplo n.º 4
0
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its distances to all other points

    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  #number of bins (starting at 0)
    threeD = Bool(True)

    # divide by the sum of all radial bins. If not performed, the first principle component will likely be average density
    normaliseRelativeDensity = Bool(False)

    def execute(self, namespace):
        """
        Compute a distance histogram for every point and store the processed features in the namespace.

        Uses x, y (and z if ``threeD``) columns from ``namespace[inputLocalisations]``; the result of
        ``_process_features`` is stored under ``outputName``.
        """
        from PYME.Analysis.points import DistHist
        points = namespace[self.inputLocalisations]

        # NB: ``range`` rather than ``xrange`` so that this also runs on Python 3
        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([
                DistHist.distanceHistogram3D(x[i], y[i], z[i], x, y, z,
                                             self.numBins, self.binWidth)
                for i in range(len(x))
            ])
        else:
            x, y = points['x'], points['y']
            f = np.array([
                DistHist.distanceHistogram(x[i], y[i], x, y, self.numBins,
                                           self.binWidth)
                for i in range(len(x))
            ])

        if self.normaliseRelativeDensity:
            # scale each point's histogram to unit sum so that features describe the shape of the radial
            # distribution rather than the local density. Rows with no counts are left at zero.
            f = f.astype('f8')
            totals = f.sum(axis=-1, keepdims=True)
            totals[totals == 0] = 1
            f = f / totals

        namespace[self.outputName] = self._process_features(points, f)
Exemplo n.º 5
0
class EngineLayer(BaseLayer):
    """
    Base class for layers who delegate their rendering to an engine.

    Derived classes supply the geometry via ``get_vertices``, ``get_normals`` and ``get_colors``. The engine
    (``self.engine``) is called with the canvas and this layer, and uses these methods to obtain the data it draws.
    """
    engine = Instance(BaseEngine)
    show_lut = Bool(True)

    def render(self, gl_canvas):
        """Render via ``self.engine`` if the layer is visible; returns whatever the engine returns (None if hidden)."""
        if self.visible:
            return self.engine.render(gl_canvas, self)

    # NOTE(review): abc.abstractmethod is only enforced when the class' metaclass is ABCMeta; whether that is the
    # case for this HasTraits hierarchy is not visible here, so the NotImplementedError below is the actual guard.
    @abc.abstractmethod
    def get_vertices(self):
        """
        Provides the engine with a way of obtaining vertex data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of vertices suitable for passing to glVertexPointerf()

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_normals(self):
        """
        Provides the engine with a way of obtaining per-vertex normal data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of normals suitable for passing to glNormalPointerf()

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_colors(self):
        """
        Provides the engine with a way of obtaining per-vertex colour data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of colours suitable for passing to glColorPointerf()

        """
        raise NotImplementedError()
Exemplo n.º 6
0
class ExtractChannelByName(ModuleBase):
    """Extract one channel from an image using regular expression matching to image channel names - by default this is case insensitive"""
    inputName = Input('input')
    outputName = Output('filtered_image')

    channelNamePattern = CStr('channel0')
    caseInsensitive = Bool(True)

    def _matchChannels(self, channelNames):
        """Return the indices of all names in ``channelNames`` which ``channelNamePattern`` matches (via re.search)."""
        # kept as a separate method so that it can be exercised on its own in tests
        import re
        flags = re.I if self.caseInsensitive else 0
        pattern = self.channelNamePattern
        return [idx for idx, name in enumerate(channelNames) if re.search(pattern, name, flags)]

    def _pickChannel(self, image):
        """
        Return a new ImageStack containing the single channel of ``image`` whose name matches the pattern.

        Raises
        ------
        RuntimeError
            if the pattern matches no channel names, or more than one.
        """
        channelNames = image.mdh['ChannelNames']
        matches = self._matchChannels(channelNames)

        if not matches:
            raise RuntimeError(
                "Expression '%s' did not match any channel names" %
                self.channelNamePattern)
        if len(matches) > 1:
            matched_names = ', '.join(channelNames[i] for i in matches)
            raise RuntimeError(
                ("Expression '%s' did match more than one channel name: " %
                 self.channelNamePattern) + matched_names)

        (idx,) = matches

        im = ImageStack(image.data[:, :, :, idx], titleStub='Filtered Image')
        im.mdh.copyEntriesFrom(image.mdh)
        im.mdh['ChannelNames'] = [channelNames[idx]]
        im.mdh['Parent'] = image.filename

        return im

    def execute(self, namespace):
        """Store the extracted channel of ``namespace[inputName]`` under ``outputName``."""
        source = namespace[self.inputName]
        namespace[self.outputName] = self._pickChannel(source)
Exemplo n.º 7
0
class BaseLayer(HasTraits):
    """
    A single layer to be rendered.

    A layer should represent a fairly high level concept - e.g. a point cloud of data coming from XX, or a surface
    representation of YY. If such a layer can be rendered in multiple different but similar ways (e.g.
    points/point sprites, or shaded/wireframe) which otherwise share common settings (point size, point colour,
    etc.), these representations should be coded as one layer with a selectable rendering backend, or 'engine',
    responsible for managing shaders and actually executing the OpenGL code. In this case use the `EngineLayer`
    class as a base.

    In simpler cases, such as rendering an overlay, it is acceptable for a layer to do its own rendering and manage
    its own shader. In this case, use `SimpleLayer` as a base.
    """
    visible = Bool(True)

    def __init__(self, context=None, **kwargs):
        # Only the rendering context is kept.
        # NOTE(review): HasTraits.__init__ is not called, so any extra keyword arguments are currently dropped
        # (an earlier, commented-out call to it was also missing ``self``) - confirm this is intentional.
        self._context = context

    @property
    def bbox(self):
        """
        Bounding box in the form [x0, y0, z0, x1, y1, z1], or None if a bounding box does not make sense for this
        layer.

        This base implementation always returns None; over-ride in derived classes.
        """
        return None

    @abc.abstractmethod
    def render(self, gl_canvas):
        """
        Draw the layer. Must be over-ridden in derived classes, and should check ``self.visible`` before drawing
        anything.

        Parameters
        ----------
        gl_canvas : the canvas to draw to - an instance of PYME.LMVis.gl_render3D_shaders.LMGLShaderCanvas
        """
Exemplo n.º 8
0
class LabelByRegionProperty(Filter):
    """Assigns a region property to each contiguous region in the input mask.
    Optionally throws away all regions for which the property is outside a given range.
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def applyFilter(self, data, chanNum, frNum, im):
        """
        Label connected regions of ``data > 0.5`` and paint each one with its value of ``regionProperty``.

        Returns a float array the same shape as ``data``: 0 outside regions (and inside regions rejected by the
        optional [propertyMin, propertyMax] filter), the property value inside kept regions.
        """
        mask = data > 0.5
        labels, _n_labels = ndimage.label(mask)
        regions = skimage.measure.regionprops(labels, None, cache=True)

        result = np.zeros_like(mask, dtype='float')
        bounding_slices = ndimage.find_objects(labels)
        prop = self.regionProperty

        for region in regions:
            sl = bounding_slices[region.label - 1]
            in_region = labels[sl] == region.label

            if prop == 'area':
                value = region.area
            elif prop == 'aspectratio':
                value = region.major_axis_length / region.minor_axis_length
            elif prop == 'circularity':
                # 4*pi*A / P**2 - equals 1 for a perfect circle
                value = 4 * math.pi * region.area / (region.perimeter * region.perimeter)

            keep = (not self.filterByProperty) or (self.propertyMin <= value <= self.propertyMax)
            if keep:
                result[sl] += in_region * value

        return result

    def completeMetadata(self, im):
        """Record which property was used for the labelling."""
        im.mdh['Labelling.Property'] = self.regionProperty
class DetectPSF(ModuleBase):
    """
    Detect PSFs using a difference of Gaussian (DoG) blob detector.

    The input image dims are X, Y, Z, C, where each channel C is processed independently by detecting blobs in its
    mean projection along Z. ``namespace[output_pos]`` receives a list with one entry per channel; each entry is an
    integer array with one row per detection and columns (x, y, z, sigma). z is currently always the central
    slice (only ``ignore_z=True`` is implemented).
    """

    inputName = Input('input')

    # min_sigma / max_sigma are divided by mdh['voxelsize.x'] before being passed to the detector
    # NOTE(review): presumably given in the same units as voxelsize.x - confirm
    min_sigma = Float(1.0)
    max_sigma = Float(3.0)
    sigma_ratio = Float(1.6)
    percent_threshold = Float(0.1)  # detection threshold as a fraction of the normalised projection's maximum
    overlap = Float(0.5)
    exclude_border = Int(50)  # detections within this many pixels of the x/y edges are discarded
    ignore_z = Bool(True)

    output_pos = Output('psf_pos')
    # output_img = Output('output')

    def execute(self, namespace):
        """
        Detect PSFs in each channel of ``namespace[inputName]`` and store their positions under ``output_pos``.

        Raises
        ------
        NotImplementedError
            if ``ignore_z`` is False (z centering is not implemented).
        """
        if not self.ignore_z:
            raise NotImplementedError("z centering not yet implemented")

        ims = namespace[self.inputName]

        pixel_size = ims.mdh['voxelsize.x']
        z_centre = ims.data.shape[2] // 2

        projections = []
        pos = []
        for c in range(ims.data.shape[3]):
            mean_project = ims.data[:, :, :, c].mean(2).squeeze()
            # pixels whose mean is exactly 2**16-1 (presumably saturated in every frame of uint16 data) are
            # replaced by a fixed value so they don't dominate the normalisation
            mean_project[mean_project == 2**16 - 1] = 200
            projections.append(mean_project)

            # normalise to [0, 1]; a constant projection would otherwise divide by zero and produce NaNs
            norm = mean_project - mean_project.min()
            peak = norm.max()
            if peak > 0:
                norm = norm / peak

            # newer skimage versions support exclude_border directly; the border is filtered manually below so
            # that older versions also work
            blobs = feature.blob_dog(norm, self.min_sigma / pixel_size, self.max_sigma / pixel_size,
                                     sigma_ratio=self.sigma_ratio, overlap=self.overlap,
                                     threshold=self.percent_threshold * norm.max())

            b = self.exclude_border
            edge_mask = (blobs[:, 0] > b) & (blobs[:, 0] < norm.shape[0] - b)
            edge_mask &= (blobs[:, 1] > b) & (blobs[:, 1] < norm.shape[1] - b)
            blobs = blobs[edge_mask]

            # rows are (x, y, sigma); insert the central z slice as column 2 -> (x, y, z, sigma)
            blobs = np.insert(blobs, 2, z_centre, axis=1)
            pos.append(blobs.astype(int))  # NB: np.int was removed in NumPy 1.24

        namespace[self.output_pos] = pos

        self._plot_detections(projections, pos)

    def _plot_detections(self, projections, pos):
        """Best-effort diagnostic plot of the detections on each channel's projection; errors are printed, not raised."""
        try:
            # NOTE(review): relies on ``pyplot`` being importable at module scope (a local import was commented out
            # in the original); if it is not, the resulting NameError is caught and printed below.
            n = len(projections)
            fig, axes = pyplot.subplots(1, n, figsize=(4 * n, 3), squeeze=False)
            for c, mean_project in enumerate(projections):
                axes[0, c].imshow(mean_project)
                axes[0, c].set_axis_off()
                for x, y, z, sig in pos[c]:
                    cir = pyplot.Circle((y, x), sig, color='red', linewidth=2, fill=False)
                    axes[0, c].add_patch(cir)
        except Exception as e:
            print(e)
Exemplo n.º 10
0
class ModuleCollection(HasTraits):
    """
    An ordered collection of recipe modules plus the shared namespace they read their inputs from and write
    their outputs to.

    Modules are executed in dependency order (derived from each module's ``inputs`` and ``outputs``), and the
    collection can be serialised to / loaded from YAML or JSON.
    """
    modules = List()
    execute_on_invalidation = Bool(False)

    def __init__(self, *args, **kwargs):
        HasTraits.__init__(self, *args, **kwargs)

        self.namespace = {}

        # we open hdf files and don't necessarily read their contents into memory - these need to be closed when we
        # either delete the recipe, or clear the namespace
        self._open_input_files = []

        self.recipe_changed = dispatch.Signal()
        self.recipe_executed = dispatch.Signal()

    def invalidate_data(self):
        """Re-run the recipe if ``execute_on_invalidation`` is set, otherwise do nothing."""
        if self.execute_on_invalidation:
            self.execute()

    def clear(self):
        """Remove all entries from the namespace."""
        self.namespace.clear()

    def new_output_name(self, stub):
        """
        Suggest a namespace key based on ``stub``.

        Returns ``stub`` if no existing key starts with it; otherwise ``'<stub>_<n>'``, where ``n`` starts at the
        number of keys sharing the prefix and is incremented until the name is not already in the namespace.
        """
        # count only the keys which actually start with ``stub`` (previously every key was counted)
        count = sum(1 for k in self.namespace.keys() if k.startswith(stub))

        if count == 0:
            return stub

        name = '%s_%d' % (stub, count)
        while name in self.namespace:
            count += 1
            name = '%s_%d' % (stub, count)
        return name

    def dependancyGraph(self):
        """
        Build the dependency graph used for scheduling.

        Returns
        -------
        dict
            Maps each module to the set of its inputs, and each module output to the set ``{module}`` which
            produces it.
        """
        dg = {}

        #only add items to dependancy graph if they are not already in the namespace
        #calculated_objects = namespace.keys()

        for mod in self.modules:
            # copy, so that merging below can never mutate a set owned by the module
            s = set(mod.inputs)

            try:
                s.update(dg[mod])
            except KeyError:
                pass

            dg[mod] = s

            for op in mod.outputs:
                #if not op in calculated_objects:
                dg[op] = {
                    mod,
                }

        return dg

    def reverseDependancyGraph(self):
        """Invert ``dependancyGraph``: map each node to the set of nodes which depend on it."""
        dg = self.dependancyGraph()

        rdg = {}
        for k, vs in dg.items():
            for v in vs:
                rdg.setdefault(v, set()).add(k)

        return rdg

    def _getAllDownstream(self, rdg, keys):
        """get all the downstream items which (directly or indirectly) depend on the given keys"""

        downstream = set()

        next_level = set()

        for k in keys:
            try:
                next_level.update(rdg[k])
            except KeyError:
                pass

        if next_level:
            downstream.update(next_level)
            downstream.update(self._getAllDownstream(rdg, list(next_level)))

        return downstream

    def prune_dependencies_from_namespace(self,
                                          keys_to_prune,
                                          keep_passed_keys=False):
        """
        Remove everything downstream of ``keys_to_prune`` from the namespace (and the keys themselves, unless
        ``keep_passed_keys`` is True). Keys which are not present are ignored.
        """
        rdg = self.reverseDependancyGraph()

        if keep_passed_keys:
            downstream = list(self._getAllDownstream(rdg, list(keys_to_prune)))
        else:
            downstream = list(keys_to_prune) + list(
                self._getAllDownstream(rdg, list(keys_to_prune)))

        for dsi in downstream:
            try:
                self.namespace.pop(dsi)
            except KeyError:
                #the output is not in our namespace, no need to prune
                pass
            except AttributeError:
                #we might not have our namespace defined yet
                pass

    def resolveDependencies(self):
        """Return the modules and data keys in an order which satisfies all dependencies."""
        import toposort
        #build dependancy graph

        dg = self.dependancyGraph()

        #solve the dependency tree
        return toposort.toposort_flatten(dg, sort=False)

    @staticmethod
    def _input_has_changed(old, new):
        """Return True if a value passed to ``execute`` differs from the one already in the namespace."""
        if old is new:
            return False
        try:
            return not bool(old == new)
        except (ValueError, TypeError):
            # e.g. ``==`` on multi-element numpy arrays returns an array whose truth value is ambiguous - assume
            # the input changed (pruning too much is safe, pruning too little is not)
            return True

    def execute(self, **kwargs):
        """
        Run any modules whose outputs are not yet in the namespace.

        Keyword arguments are injected into the namespace as inputs; if one replaces a different existing value,
        everything downstream of it is pruned first so that it is recomputed. Returns ``namespace['output']`` if
        that key exists, otherwise None.
        """
        #remove anything which is downstream from changed inputs
        for k, v in kwargs.items():
            try:
                old = self.namespace[k]
            except KeyError:
                #key wasn't in namespace previously - nothing to prune
                continue

            if self._input_has_changed(old, v):
                logger.debug('pruning: %s', k)
                self.prune_dependencies_from_namespace([k])

        self.namespace.update(kwargs)

        exec_order = self.resolveDependencies()

        for m in exec_order:
            if isinstance(
                    m,
                    ModuleBase) and not m.outputs_in_namespace(self.namespace):
                try:
                    m.execute(self.namespace)
                except Exception:
                    logger.exception("Error in recipe module: %s" % m)
                    raise

        self.recipe_executed.send_robust(self)

        if 'output' in self.namespace:
            return self.namespace['output']

    @classmethod
    def fromMD(cls, md):
        """Create a recipe from a metadata handler whose top-level key prefixes are module names."""
        c = cls()

        moduleNames = set([s.split('.')[0] for s in md.keys()])

        mc = []

        for mn in moduleNames:
            mod = all_modules[mn]()
            mod.set(**md[mn])
            mc.append(mod)

        #return cls(modules=mc)
        c.modules = mc

        return c

    def get_cleaned_module_list(self):
        """
        Return a serialisable description of the modules: a list of ``{module_name: {trait: value}}`` dicts.

        Private (underscore-prefixed) traits and traits still at their default value are omitted, except for
        input* / output* traits which are always kept. Trait container values are converted to plain
        dict/list/set.
        """
        l = []
        for mod in self.modules:
            #l.append({mod.__class__.__name__: mod.get()})

            ct = mod.class_traits()

            mod_traits_cleaned = {}
            for k, v in mod.get().items():
                if not k.startswith(
                        '_'
                ):  #don't save private data - this is usually used for caching etc ..,
                    try:
                        if (not (v == ct[k].default)) or (k.startswith(
                                'input')) or (k.startswith('output')):
                            #don't save defaults
                            if isinstance(v, dict) and not type(v) == dict:
                                v = dict(v)
                            elif isinstance(v, list) and not type(v) == list:
                                v = list(v)
                            elif isinstance(v, set) and not type(v) == set:
                                v = set(v)

                            mod_traits_cleaned[k] = v
                    except KeyError:
                        # for some reason we have a trait that shouldn't be here
                        pass

            l.append({module_names[mod.__class__]: mod_traits_cleaned})

        return l

    def toYAML(self):
        """Serialise the recipe to YAML (block style) using a SafeDumper."""
        import yaml

        class MyDumper(yaml.SafeDumper):
            def represent_mapping(self, tag, value, flow_style=None):
                return super(MyDumper,
                             self).represent_mapping(tag, value, False)

        return yaml.dump(self.get_cleaned_module_list(), Dumper=MyDumper)

    def toJSON(self):
        """Serialise the recipe to JSON."""
        import json
        return json.dumps(self.get_cleaned_module_list())

    def _update_from_module_list(self, l):
        """
        Update from a parsed yaml or json list of modules
        
        It probably makes no sense to call this directly as the format is pretty wack - a
        list of dictionarys each with a single entry, but that is how the yaml parses

        Parameters
        ----------
        l: list
            List of modules as obtained from parsing a yaml recipe,
            Each module is a dictionary mapping with a single e.g.
            [{'Filtering.Filter': {'filters': {'probe': [-0.5, 0.5]}, 'input': 'localizations', 'output': 'filtered'}}]

        Returns
        -------

        """
        mc = []

        if l is None:
            l = []

        for mdd in l:
            mn, md = list(mdd.items())[0]
            try:
                mod = all_modules[mn](self)
            except KeyError:
                # still support loading old recipes which do not use hierarchical names
                # also try and support modules which might have moved
                mod = _legacy_modules[mn.split('.')[-1]](self)

            mod.set(**md)
            mc.append(mod)

        self.modules = mc

        self.recipe_changed.send_robust(self)
        self.invalidate_data()

    @classmethod
    def _from_module_list(cls, l):
        """ A factory method which contains the common logic for loading/creating from either
        yaml or json. Do not call directly"""
        c = cls()
        c._update_from_module_list(l)

        return c

    @classmethod
    def fromYAML(cls, data):
        """Create a recipe from YAML formatted text."""
        import yaml

        # safe_load: recipes are written with a SafeDumper (see toYAML). yaml.load without an explicit Loader is
        # unsafe on untrusted input and raises TypeError on PyYAML >= 6.
        l = yaml.safe_load(data)
        return cls._from_module_list(l)

    def update_from_yaml(self, data):
        """
        Update from a yaml formatted recipe description

        Parameters
        ----------
        data: str
            either yaml formatted text, or the path to a yaml file.

        Returns
        -------
        None

        """
        import os
        import yaml

        if os.path.isfile(data):
            with open(data) as f:
                data = f.read()

        # see fromYAML for why safe_load is used
        l = yaml.safe_load(data)
        return self._update_from_module_list(l)

    @classmethod
    def fromJSON(cls, data):
        """Create a recipe from JSON formatted text."""
        import json
        return cls._from_module_list(json.loads(data))

    def add_module(self, module):
        """Append ``module`` and notify listeners via ``recipe_changed``."""
        self.modules.append(module)
        self.recipe_changed.send_robust(self)

    @property
    def inputs(self):
        """Module input names starting with 'in'."""
        ip = set()
        for mod in self.modules:
            ip.update({k for k in mod.inputs if k.startswith('in')})
        return ip

    @property
    def outputs(self):
        """Module output names starting with 'out'."""
        op = set()
        for mod in self.modules:
            op.update({k for k in mod.outputs if k.startswith('out')})
        return op

    @property
    def module_outputs(self):
        """All output names of all modules."""
        op = set()
        for mod in self.modules:
            op.update(set(mod.outputs))
        return op

    @property
    def file_inputs(self):
        """Concatenation of every module's ``file_inputs``."""
        out = []
        for mod in self.modules:
            out += mod.file_inputs

        return out

    def save(self, context=None):
        """
        Find all OutputModule instances and call their save methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to a new empty dict
            (a fresh one per call, rather than a shared mutable default).

        """
        if context is None:
            context = {}

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                mod.save(self.namespace, context)

    def gather_outputs(self, context=None):
        """
        Find all OutputModule instances and call their generate methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to a new empty dict
            (a fresh one per call, rather than a shared mutable default).

        Returns
        -------
        list
            The non-None results of each OutputModule's ``generate``.
        """
        if context is None:
            context = {}

        outputs = []

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                out = mod.generate(self.namespace, context)

                if out is not None:
                    outputs.append(out)

        return outputs

    def loadInput(self, filename, key='input'):
        """Load input data from a file and inject into namespace

        Currently only handles images (anything you can open in dh5view). TODO -
        extend to other types.
        """
        #modify this to allow for different file types - currently only supports images
        from PYME.IO import unifiedIO
        import os
        extension = os.path.splitext(filename)[1]
        if extension in ['.h5r', '.h5', '.hdf']:
            import tables
            from PYME.IO import MetaDataHandler
            from PYME.IO import tabular

            with unifiedIO.local_or_temp_filename(filename) as fn:
                with tables.open_file(fn, mode='r') as h5f:
                    #make sure our hdf file gets closed

                    key_prefix = '' if key == 'input' else key + '_'

                    try:
                        mdh = MetaDataHandler.NestedClassMDHandler(
                            MetaDataHandler.HDFMDHandler(h5f))
                    except tables.FileModeError:  # Occurs if no metadata is found, since we opened the table in read-mode
                        logger.warning(
                            'No metadata found, proceeding with empty metadata'
                        )
                        mdh = MetaDataHandler.NestedClassMDHandler()

                    for t in h5f.list_nodes('/'):
                        # FIXME - The following isinstance tests are not very safe (and badly broken in some cases e.g.
                        # PZF formatted image data, Image data which is not in an EArray, etc ...)
                        # Note that EArray is only used for streaming data!
                        # They should ideally be replaced with more comprehensive tests (potentially based on array or dataset
                        # dimensionality and/or data type) - i.e. duck typing. Our strategy for images in HDF should probably
                        # also be improved / clarified - can we use hdf attributes to hint at the data intent? How do we support
                        # > 3D data?

                        if isinstance(t, tables.VLArray):
                            from PYME.IO.ragged import RaggedVLArray

                            rag = RaggedVLArray(
                                h5f, t.name, copy=True
                            )  #force an in-memory copy so we can close the hdf file properly
                            rag.mdh = mdh

                            self.namespace[key_prefix + t.name] = rag

                        elif isinstance(t, tables.table.Table):
                            #  pipe our table into h5r or hdf source depending on the extension
                            tab = tabular.H5RSource(
                                h5f, t.name
                            ) if extension == '.h5r' else tabular.HDFSource(
                                h5f, t.name)
                            tab.mdh = mdh

                            self.namespace[key_prefix + t.name] = tab

                        elif isinstance(t, tables.EArray):
                            # load using ImageStack._loadh5, which finds metadata
                            im = ImageStack(filename=filename, haveGUI=False)
                            # assume image is the main table in the file and give it the named key
                            self.namespace[key] = im

        elif extension == '.csv':
            logger.error('loading .csv not supported yet')
            raise NotImplementedError
        elif extension in ['.xls', '.xlsx']:
            logger.error('loading .xls not supported yet')
            raise NotImplementedError
        else:
            self.namespace[key] = ImageStack(filename=filename, haveGUI=False)

    @property
    def pipeline_view(self):
        """A traitsui View listing every module's ``pipeline_view``, or None when no wx App is running."""
        import wx
        if wx.GetApp() is None:
            return None
        else:
            from traitsui.api import View, ListEditor, InstanceEditor, Item

            return View(Item('modules',
                             editor=ListEditor(
                                 style='custom',
                                 editor=InstanceEditor(view='pipeline_view'),
                                 mutable=False),
                             style='custom',
                             show_label=False),
                        buttons=['OK', 'Cancel'])

    def to_svg(self):
        """Render the dependency graph as SVG."""
        from . import recipeLayout
        return recipeLayout.to_svg(self.dependancyGraph())

    def _repr_svg_(self):
        """ Make us look pretty in Jupyter"""
        return self.to_svg()
Exemplo n.º 11
0
class Filter(ModuleBase):
    """
    Module with one image input and one image output.

    Derived classes implement ``applyFilter(data, chanNum, frNum, image)``. It is called either once per
    (z plane, channel) with a squeezed float32 slice (``processFramesIndividually=True``), or once per channel with
    that channel's whole squeezed stack and ``frNum=0``. Derived classes may also override ``completeMetadata(im)``
    to annotate the output image.
    """
    inputName = Input('input')
    outputName = Output('filtered_image')

    processFramesIndividually = Bool(True)

    def filter(self, image):
        """Apply ``applyFilter`` to ``image`` and wrap the per-channel results in a new ImageStack."""
        if self.processFramesIndividually:
            filt_ims = []
            for chanNum in range(image.data.shape[3]):
                filt_ims.append(
                    np.concatenate([
                        np.atleast_3d(
                            self.applyFilter(
                                image.data[:, :, i,
                                           chanNum].squeeze().astype('f'),
                                chanNum, i, image))
                        for i in range(image.data.shape[2])
                    ], 2))
        else:
            filt_ims = [
                np.atleast_3d(
                    self.applyFilter(
                        image.data[:, :, :, chanNum].squeeze().astype('f'),
                        chanNum, 0, image))
                for chanNum in range(image.data.shape[3])
            ]

        im = ImageStack(filt_ims, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(image.mdh)
        im.mdh['Parent'] = image.filename

        self.completeMetadata(im)

        return im

    def execute(self, namespace):
        """Filter ``namespace[inputName]`` and store the result under ``outputName``."""
        namespace[self.outputName] = self.filter(namespace[self.inputName])

    def completeMetadata(self, im):
        """Hook for derived classes to add metadata to the output image; does nothing by default."""
        pass

    @classmethod
    def dsviewer_plugin_callback(cls, dsviewer, showGUI=True, **kwargs):
        """Implements a callback which allows this module to be used as a plugin for dsviewer.

        Parameters
        ----------

        dsviewer : :class:`PYME.DSView.dsviewer.DSViewFrame` instance
            This is the current :class:`~PYME.DSView.dsviewer.DSViewFrame` instance. The filter will be run with the
            associated ``.image`` as input and display the output in a new window.

        showGUI : bool
            Should we show a GUI to set parameters (generated by calling configure_traits()), or just run with default
            parameters.

        **kwargs : dict
            Optionally, provide default values for parameters. Makes most sense when used with showGUI = False

        """
        from PYME.DSView import ViewIm3D

        mod = cls(inputName='input', outputName='output', **kwargs)
        if (not showGUI) or mod.configure_traits(kind='modal'):
            namespace = {'input': dsviewer.image}
            mod.execute(namespace)

            # execute() stores its result in the namespace under outputName ('output'), not on the module itself
            ViewIm3D(namespace['output'],
                     parent=dsviewer,
                     glCanvas=dsviewer.glCanvas)
Exemplo n.º 12
0
class CalculateFRCBase(ModuleBase):
    """
    Base class for Fourier ring / shell correlation (FRC / FSC) modules.

    Refer to derived classes for user-facing docstrings. This class holds the
    shared machinery:

    - input pre-processing (shape check and optional Tukey windowing);
    - FFT-based correlation averaged over bins of spatial frequency;
    - smoothing of the resulting curve;
    - estimation of the resolution at which the curve crosses the 1/7,
      3 sigma and 1/2 bit thresholds.

    Derived classes are expected to set ``self._namespace`` and
    ``self._pixel_size_in_nm`` (and ``self._pool`` when ``multiprocessing`` is
    enabled) before calling :meth:`calculate_FRC_from_images`.
    """
    pre_filter = Enum(['Tukey_1/8', None])
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    multiprocessing = Bool(True)
    cubic_smoothing = Float(0.01)

    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, *args, **kwargs):
        """Not implemented here; derived classes must override ``execute(namespace)``."""
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Check that the images share a shape and apply the configured pre-filter.

        Parameters
        ----------
        image_pair : sequence of ndarray
            Images to be correlated.

        Returns
        -------
        sequence of ndarray
            The images, windowed if ``pre_filter`` is ``'Tukey_1/8'``.

        Raises
        ------
        AssertionError
            If the images do not all have the same shape.
        ValueError
            If ``pre_filter`` is not a recognised option.
        """
        reference_shape = np.shape(image_pair[0])
        assert all(np.shape(im) == reference_shape for im in image_pair), "Images not the same dimension."

        # apply filtering to zero near edge of images
        if self.pre_filter == 'Tukey_1/8':
            image_pair = self.filter_images_tukey(image_pair, 1. / 8)
        elif self.pre_filter is not None:
            raise ValueError("Unknown pre_filter: {!r}".format(self.pre_filter))

        return image_pair

    def filter_images_tukey(self, images, alpha):
        """
        Multiply each image by a separable N-d Tukey window.

        Parameters
        ----------
        images : sequence of ndarray
            Images of identical shape (the window is built from ``images[0]``).
        alpha : float
            Tukey shape parameter (fraction of the window inside the tapered region).

        Returns
        -------
        list of ndarray
            Windowed copies of the images.
        """
        try:
            from scipy.signal.windows import tukey
        except ImportError:  # very old SciPy without the windows submodule
            from scipy.signal import tukey

        windows = [tukey(n, alpha=alpha) for n in np.shape(images[0])]
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)

        return [im * window_nd for im in images]

    def calculate_FRC_from_images(self, image_pair, mdh, n_bins=201):
        """
        Compute the FRC/FSC curve of two images and derive resolution estimates.

        Parameters
        ----------
        image_pair : sequence of two ndarray
            Pre-processed images of identical shape.
        mdh :
            Not used by this method; kept for interface compatibility.
        n_bins : int
            Number of bin edges (passed to :meth:`BinData`) over spatial frequency.

        Returns
        -------
        res : dict
            Resolutions (inverse of the crossing frequency) for each threshold.
        rawdata : dict
            Curves used for the estimate (see :meth:`calculate_threshold`).
        """
        if self.multiprocessing:
            pending = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [job.get() for job in pending]
            del pending
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        im_fft_freqs = [np.fft.fftfreq(image_pair[0].shape[i], self._pixel_size_in_nm[i])
                        for i in range(image_pair[0].ndim)]
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        im1_fft_power = ft_images[0] * np.conj(ft_images[0])
        im2_fft_power = ft_images[1] * np.conj(ft_images[1])
        im12_fft_power = ft_images[0] * np.conj(ft_images[1])

        try:
            self._namespace[self.output_fft_images_cc] = ImageStack(
                data=np.stack([np.atleast_3d(np.fft.fftshift(im1_fft_power)),
                               np.atleast_3d(np.fft.fftshift(im2_fft_power)),
                               np.atleast_3d(np.fft.fftshift(im12_fft_power))], 3),
                titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            # best effort: this diagnostic image is not needed for the FRC itself
            print(e)

        R_flat = im_R.flatten()
        im1_fft_flat_res = CalculateFRCBase.BinData(R_flat, im1_fft_power.flatten(), statistic='mean', bins=n_bins)
        im2_fft_flat_res = CalculateFRCBase.BinData(R_flat, im2_fft_power.flatten(), statistic='mean', bins=n_bins)
        im12_fft_flat_res = CalculateFRCBase.BinData(R_flat, im12_fft_power.flatten(), statistic='mean', bins=n_bins)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic * im2_fft_flat_res.statistic))

        freq = im12_fft_flat_res.bin_edges[:-1]
        smoothed_frc = self.smooth_frc(freq, corr, self.cubic_smoothing)

        res, rawdata = self.calculate_threshold(freq, corr, smoothed_frc, im12_fft_flat_res.counts)

        return res, rawdata

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """
        Build a callable approximating the FRC curve, according to ``frc_smoothing_func``.

        The returned callable accepts the frequency as keyword argument ``x``.

        Raises
        ------
        ValueError
            If ``frc_smoothing_func`` is not a recognised option.
        """
        if self.frc_smoothing_func is None:
            return interpolate.interp1d(freq, corr, kind='next')

        elif self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            # integer division: the index must be an int on Python 3
            x0 = [1, freq[len(freq) // 2]]
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x) - corr)),
                                        x0, args=(freq,), method='Nelder-Mead')
            return partial(func, n=fit_res.x[0], c=fit_res.x[1])

        elif self.frc_smoothing_func == "Cubic Spline":
            return interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)

        raise ValueError("Unknown frc_smoothing_func: {!r}".format(self.frc_smoothing_func))

    @staticmethod
    def _find_threshold_crossing(freq, corr_func, threshold_func):
        """
        Locate where the smoothed curve ``corr_func`` meets ``threshold_func``.

        The initial guess is the first sampled frequency at which the curve lies
        below the threshold (``freq[0]`` if it never does). Nelder-Mead then
        minimises the squared difference from that guess.
        """
        below = corr_func(x=freq) - threshold_func(freq) < 0
        x0 = freq[np.argmax(below)]
        return optimize.minimize(lambda x: np.square(corr_func(x=x) - threshold_func(x)),
                                 x0, method='Nelder-Mead')

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Estimate resolution from the 1/7, 3 sigma and 1/2 bit threshold crossings.

        Also stores a :class:`Plot` of the curves in ``self._namespace[self.output_frc_plot]``.

        Returns
        -------
        res : dict
            ``'frc 1/7'``, ``'frc 3 sigma'`` and ``'frc half bit'`` resolutions
            (inverse of the crossing frequency).
        rawdata : dict
            Frequencies, raw and smoothed curves and the threshold curves.
        """
        res = dict()

        fsc_0143 = self._find_threshold_crossing(freq, corr_func, lambda f: 0.143)
        res['frc 1/7'] = 1. / fsc_0143.x[0]

        sigma = 1.0 / np.sqrt(counts * 0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = self._find_threshold_crossing(freq, corr_func, lambda f: 3. * sigma_spl(f))
        res['frc 3 sigma'] = 1. / fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        half_bit = (0.2071 + 1.9102 / np.sqrt(counts)) / (1.2071 + 0.9102 / np.sqrt(counts))
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = self._find_threshold_crossing(freq, corr_func, half_bit_spl)
        res['frc half bit'] = 1. / fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1, 2, figsize=(10, 4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7:  {:.2f} nm".format(1. / fsc_0143.x[0])

            axes[0].plot(freq, 3 * sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma:  {:.2f} nm".format(1. / fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')
            frc_text += "\n1/2 bit:  {:.2f} nm".format(1. / fsc_half_bit.x[0])

            axes[0].legend()

            # x axis is spatial frequency; label ticks with the corresponding resolution
            x_ticklocs = axes[0].get_xticks()
            axes[0].set_xticklabels(["{:.1f}".format(1. / i) for i in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center', verticalalignment='center',
                         transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()

        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq': freq, 'corr': corr, 'smooth': corr_func(x=freq), '1/7': np.ones_like(freq) / 7,
                   '3 sigma': 3 * sigma_spl(freq), '1/2 bit': half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """Save raw curves and results to ``save_path`` (``.npz``) if a path is set."""
        # compare by value/truthiness - ``is not ""`` tests object identity
        if self.save_path:
            np.savez_compressed(self.save_path, raw=namespace[self.output_frc_raw],
                                results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Calculate a binned statistic of ``data`` over ``indexes``. Supports complex data.

        Parameters
        ----------
        indexes : ndarray
            Values used to assign each data point to a bin.
        data : ndarray
            Values to aggregate; same number of elements as ``indexes``.
        statistic : {'mean', 'sum'}
            Aggregation applied per bin.
        bins : int
            Number of equally spaced bin edges between ``indexes.min()`` and ``indexes.max()``.

        Returns
        -------
        Result
            Object with ``statistic`` (per-bin values, dtype of ``data``),
            ``bin_edges`` and ``counts`` attributes.

        Raises
        ------
        ValueError
            If ``statistic`` is not supported.
        """
        if statistic == 'mean':
            func = np.mean
        elif statistic == 'sum':
            func = np.sum
        else:
            raise ValueError("Unsupported statistic: {!r}".format(statistic))

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        edges = np.linspace(indexes.min(), indexes.max(), bins)
        binned = np.zeros(len(edges) - 1, dtype=data.dtype)
        counts = np.zeros(len(edges) - 1, dtype=int)

        flat_indexes = indexes.flatten()
        order = np.argsort(flat_indexes)
        indexes_sorted = flat_indexes[order]
        data_sorted = data.flatten()[order]
        # NOTE(review): side='left' means samples equal to the maximum edge fall
        # outside the last bin (unchanged from the original behaviour).
        edge_positions = np.searchsorted(indexes_sorted, edges)

        for i in range(len(edges) - 1):
            values = data_sorted[edge_positions[i]:edge_positions[i + 1]]
            binned[i] = func(values)
            counts[i] = len(values)

        res = Result()
        res.statistic = binned
        res.bin_edges = edges
        res.counts = counts
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """Decreasing logistic curve ``1 / (1 + exp(n * (x - c)))`` (0.5 at ``x == c`` for n > 0)."""
        res = 1 - 1 / (1 + np.exp(n * (-x + c)))
        return res
Exemplo n.º 13
0
class RGBImageUpload(ImageUpload):
    """
    Render an image as an RGB PNG and upload it to an OMERO server,
    optionally attaching localization files.

    Parameters
    ----------
    input_image : str
        name of the image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabular
        objects will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern used to build the image name on the OMERO server
    omero_dataset : str
        name of the OMERO dataset to add the image to. It is created if it
        does not already exist.
    omero_project : str
        name of the OMERO project to link the dataset to. It is created if it
        does not already exist.
    zoom : float
        zoom factor applied when rendering the image
    scaling : str
        intensity scaling method - one of 'min-max' or 'percentile'
    scaling_factor: float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        For multi-channel images, use cyan, magenta, and yellow rather than
        RGB. True by default.

    Notes
    -----
    OMERO server address and user login information must be stored in the user
    PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        """Render ``image`` to an 8-bit colour array and write it to ``path`` with PIL."""
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        # CMY is only used when requested AND there is more than one channel
        use_cmy = self.colorblind_friendly and (image.data.shape[3] != 1)
        renderer = image_to_cmy if use_cmy else image_to_rgb

        rendered = renderer(image,
                            zoom=self.zoom,
                            scaling=self.scaling,
                            scaling_factor=self.scaling_factor)

        Image.fromarray(rendered, mode='RGB').save(path)
Exemplo n.º 14
0
class LabelRange(Filter):
    """Assign a unique integer label to each contiguous region of the input mask.

    Regions whose pixel count falls outside [minRegionPixels, maxRegionPixels]
    are discarded. The second input (``inputSitesLabeled``) is used to count
    the sites inside each region. Only regions whose site count lies in
    [minSites, maxSites] are retained. Surviving regions are relabelled.
    """
    inputSitesLabeled = Input(
        "sites")  # sites and the main input must have the same shape!
    minRegionPixels = Int(10)
    maxRegionPixels = Int(100)
    minSites = Int(4)
    maxSites = Int(6)
    sitesAsMaxima = Bool(False)

    def filter(self, image, imagesites):
        """Apply :meth:`applyFilter` per channel (and per frame if requested), returning an ImageStack."""
        n_channels = image.data.shape[3]

        if self.processFramesIndividually:
            filt_ims = []
            for chanNum in range(n_channels):
                frames = []
                for frame in range(image.data.shape[2]):
                    plane = image.data[:, :, frame, chanNum].squeeze().astype('f')
                    site_plane = imagesites.data[:, :, frame, chanNum].squeeze().astype('f')
                    frames.append(np.atleast_3d(
                        self.applyFilter(plane, site_plane, chanNum, frame, image)))
                filt_ims.append(np.concatenate(frames, 2))
        else:
            filt_ims = []
            for chanNum in range(n_channels):
                volume = image.data[:, :, :, chanNum].squeeze().astype('f')
                site_volume = imagesites.data[:, :, :, chanNum].squeeze().astype('f')
                filt_ims.append(np.atleast_3d(
                    self.applyFilter(volume, site_volume, chanNum, 0, image)))

        result = ImageStack(filt_ims, titleStub=self.outputName)
        result.mdh.copyEntriesFrom(image.mdh)
        result.mdh['Parent'] = image.filename

        self.completeMetadata(result)

        return result

    def execute(self, namespace):
        """Filter the main input using the sites input and store the labelled image."""
        source = namespace[self.inputName]
        sites = namespace[self.inputSitesLabeled]
        namespace[self.outputName] = self.filter(source, sites)

    def applyFilter(self, data, sites, chanNum, frNum, im):
        """Label ``data > 0.5`` and keep only regions passing the size and site-count tests."""
        mask = data > 0.5
        labels, _ = ndimage.label(mask)

        keep = np.zeros_like(mask)
        for label_index, bbox in enumerate(ndimage.find_objects(labels), start=1):
            region = labels[bbox] == label_index
            area = region.sum()
            if not (self.minRegionPixels <= area <= self.maxRegionPixels):
                continue

            region_sites = sites[bbox][region]
            if self.sitesAsMaxima:
                nsites = region_sites.sum()
            else:
                # number of distinct non-zero (non-background) site labels
                nsites = (np.unique(region_sites) > 0).sum()

            if self.minSites <= nsites <= self.maxSites:
                keep[bbox] |= region

        relabelled, _ = ndimage.label(keep)
        return relabelled

    def completeMetadata(self, im):
        """Record the size and site-count limits in the output metadata."""
        for key, value in (('Labelling.MinSize', self.minRegionPixels),
                           ('Labelling.MaxSize', self.maxRegionPixels),
                           ('Labelling.MinSites', self.minSites),
                           ('Labelling.MaxSites', self.maxSites)):
            im.mdh[key] = value
Exemplo n.º 15
0
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Take a pair of images and calculate the Fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    input_image_a : ImageStack
        First of two images.

    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images. If blank, the first image is
        compared against itself.
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or perform a maximum projection (<0) for the first image.
    image_b_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or perform a maximum projection (<0) for the second image.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing  : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output

    """

    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)

    @staticmethod
    def _get_pixel_size_in_nm(image):
        """Return the (x, y, z) voxel size of ``image`` in nm; z stays 0 if unavailable."""
        voxelsize = image.mdh.voxelsize
        pixel_size = np.zeros(3, dtype=float)
        pixel_size[0] = voxelsize.x
        pixel_size[1] = voxelsize.y
        try:
            pixel_size[2] = voxelsize.z
        except Exception:
            # z size is optional (e.g. 2D data)
            pass
        if voxelsize.units == 'um':
            pixel_size *= 1.E3
        return pixel_size

    def _select_channel_data(self, image, z_index):
        """Extract the ``c_channel`` data, reducing z to one plane / a max projection if ``flatten_z``."""
        data = image.data[:, :, :, self.c_channel].squeeze()
        # squeeze() already removes a singleton z axis, in which case there is nothing to flatten
        if self.flatten_z and data.ndim == 3:
            if z_index >= 0:
                data = data[:, :, z_index]
            else:
                data = data.max(2)
        return data

    def execute(self, namespace):
        import multiprocessing

        self._namespace = namespace

        if self.multiprocessing:
            # only two FFTs are run in parallel; use at most one spare core
            # each, but always at least one worker (Pool(0) raises)
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            image_a = namespace[self.input_image_a]

            if len(self.image_b_path.strip()) == 0:
                image_b = image_a
            else:
                image_b = ImageStack(filename=self.image_b_path)

            self._pixel_size_in_nm = self._get_pixel_size_in_nm(image_a)

            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
            image_pair = [self._select_channel_data(image_a, self.image_a_z),
                          self._select_channel_data(image_b, self.image_b_z)]

            image_pair = self.preprocess_images(image_pair)

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release worker processes even if the calculation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.

    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Runtime will vary hugely depending on size of dataset and settings.

    ``cache_fft`` is necessary for large datasets.

    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.*   Dataset to correct.

    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results.
    output_drift_plot : Plot
        *Deprecated.*   Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.*   Drift-corrected dataset.

    Parameters
    ----------
    step : Int
        Setting for image construction. Step size between images
    window : Int
        Setting for image construction. Number of frames used per image. Should be equal or larger than step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Float
        Size of correlation window. Frames are only compared if within this frame range. N/A for DCC.
    multiprocessing : Bool
        Enables multiprocessing (used as an on/off flag).
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """

    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')

    def _odd_length_bin_edges(self, values, pad_length_two=True):
        """
        Histogram bin edges covering ``values`` at ``binsize`` spacing.

        An extra edge is appended when the number of edges is even, so the
        number of bins (and hence the image size) is even. When
        ``pad_length_two`` is False, a 2-edge (single bin) array is left as is.
        """
        edges = np.arange(values.min(), values.max() + self.binsize + 1, self.binsize)
        if edges.shape[0] % 2 == 0 and (pad_length_two or edges.shape[0] > 2):
            edges = np.concatenate([edges, [edges[-1] + edges[1] - edges[0]]])
        return edges

    def calc_corr_drift_from_locs(self, x, y, z, t):
        """
        Bin localizations into one image per time window, Fourier transform the
        images and cross-correlate them (via ``calc_corr_drift_from_ft_images``).

        Returns
        -------
        time_values_mid : ndarray
            Centre time of each window.
        shifts : ndarray
            Shifts from ``calc_corr_drift_from_ft_images``, scaled by ``binsize``
            and with the axes restored to x, y, z order.
        coefs :
            Coefficients from ``calc_corr_drift_from_ft_images``.
        """
        # bin edges for histogram, padded to an odd number of edges so image size is even
        bx = self._odd_length_bin_edges(x)
        by = self._odd_length_bin_edges(y)
        bz = self._odd_length_bin_edges(z, pad_length_two=False)
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Ops. Image not correctly padded to even size."

        # start time of all windows, allow partial window near end of pipeline
        time_values = np.arange(t.min(), t.max() + 1, self.step)
        # 2d array, start and end time of windows
        time_values = np.stack([time_values, np.clip(time_values + self.window, None, t.max())], axis=1)
        n_steps = time_values.shape[0]
        # center time of each window for returning. last window may have different spacing
        time_values_mid = time_values.mean(axis=1)

        if np.any(np.diff(t) < 0):  # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]

        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1] - 1, side='right')

        # Fourier transformed (and binned) set of images to correlate against one another.
        # Crude way of swapping the longest axis to the last for optimizing rfft performance.
        # Code changed for this is limited to this method.
        xyz = np.asarray([x, y, z])
        # The edge arrays generally differ in length, so hold them in a 1D object
        # array; np.asarray on a ragged list raises in NumPy >= 1.24.
        bxyz = np.empty(3, dtype=object)
        for axis, edges in enumerate((bx, by, bz)):
            bxyz[axis] = edges
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = bxyz[dims_order]
        dims_length = dims_length[dims_order]

        # use memmap for caching if cache_fft is defined
        ft_shape = (n_steps, dims_length[0] - 1, dims_length[1] - 1, (dims_length[2] - 1) // 2 + 1)
        if self.cache_fft == "":
            ft_images = np.zeros(ft_shape, dtype=complex)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=complex, mode='w+', shape=ft_shape)

        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))

        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))

        # progress is reported roughly 5 times; guard against n_steps < 5 (modulo by zero)
        report_every = max(1, n_steps // 5)

        # fill ft_images
        # if multiprocessing, can either use or not caching
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:, slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size)
                    for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                # with a cache file the workers write directly into the memmap
                if self.cache_fft == "":
                    ft_images[j] = res

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i + 1, n_steps))

        else:
            # For each window we wish to correlate, generate an image and store its ft
            for i, ti in enumerate(time_indexes):
                ft_images[i] = calc_fft_from_locs(xyz[:, slice(*ti)].T, bxyz, filter_size=self.tukey_size)

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i + 1, n_steps))

        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))

        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)

        # clean up of ft_images, potentially really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images

        # dims_order is a single swap, so indexing with it again undoes the permutation
        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        import multiprocessing

        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(1, multiprocessing.cpu_count() - 1)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=process_count)})

        try:
            locs = namespace[self.input_for_correction]

            drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'], locs['z'] * (0 if self.flatten_z else 1), locs['t'])
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # release worker processes even if drift estimation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        had_drift_columns = 'dx' in out.keys()
        out.addColumn('dx', dx)
        out.addColumn('dy', dy)
        out.addColumn('dz', dz)
        if had_drift_columns:
            # getting around oddity with mappingFilter:
            # addColumn adds a new column but also keeps the old column,
            # __getitem__ returns the new column but mappings use the old column.
            # Wrap with another level of mappingFilter so the new column becomes the 'old column'
            out = tabular.mappingFilter(out)
        out.setMapping('x', 'x + dx')
        out.setMapping('y', 'y + dy')
        out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))

        namespace[self.output_cross_cor] = self._cc_image
Exemplo n.º 17
0
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.

    Geometry is read from the datasource named by ``dsname``. Vertices, normals
    and colours are cached by ``update_from_datasource`` and exposed through
    ``get_vertices`` / ``get_normals`` / ``get_colors``. The rendering engine is
    built from ``ENGINES[method]``.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant',
                        desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour faces')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    display_normals = Bool(False)
    normal_scaling = Float(10.0)
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)'
    )
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self,
                 pipeline,
                 method='wireframe',
                 dsname='',
                 context=None,
                 **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(
            self.update,
            'cmap, clim, alpha, dsname, normal_mode, display_normals, normal_scaling'
        )
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            try:
                self._pipeline.onRebuild.connect(self.update)
            except AttributeError:
                pass

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).

        Tries ``pipeline.get_layer_data(dsname)`` first, and falls back to
        ``pipeline[dsname]``. Returns None if neither lookup is possible, for
        example when no pipeline was given or ``dsname`` is unknown.
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            try:
                return self._pipeline[self.dsname]
            except (AttributeError, KeyError, TypeError):
                # TypeError: no pipeline at all (None is not subscriptable).
                # KeyError: dsname is not present in a dict-like pipeline.
                return None

    @property
    def _ds_class(self):
        # from PYME.experimental import triangle_mesh
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """(Re)create the rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the values of the colour variable, or ``[0, 1]`` if unavailable."""
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the new colour variable, then refresh."""
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """Refresh the GUI choice lists and, if possible, the cached geometry, then notify listeners."""
        try:
            self._datasource_choices = [
                k for k, v in self._pipeline.dataSources.items()
                if isinstance(v, self._ds_class)
            ]
        except AttributeError:
            pass

        if not self.datasource is None:
            dks = [
                'constant',
            ]
            if hasattr(self.datasource, 'keys'):
                dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Pull vertices, normals and colours from a triangle-mesh datasource and cache them.

        Each face contributes three (unshared) vertices. Normals are taken per
        vertex or per face depending on ``normal_mode``.

        Parameters
        ----------
        ds :
            Mesh object exposing ``vertices``, ``faces``, ``vertex_normals`` and
            ``face_normals``, and supporting ``ds[name]`` lookup of per-vertex
            values (presumably a PYME triangle mesh - TODO confirm).

        Returns
        -------
        None
        """
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            # replicate each face normal for the face's three vertices
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                self._normals = np.vstack(
                    (xn.ravel(), yn.ravel(),
                     zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array(
                [x.min(), y.min(),
                 z.min(), x.max(),
                 y.max(), z.max()])
        else:
            self._bbox = None

        if clim is not None and c is not None and cmap is not None:
            span = clim[1] - clim[0]
            if span == 0:
                # Degenerate limits (e.g. a constant colour variable, for which
                # _update sets clim = [v, v]) would otherwise divide by zero
                # and yield NaN colours.
                span = 1.0
            cs_ = (c - clim[0]) / span
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if not self._vertices is None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices'),
                     visible_when='_datasource_choices')
            ]),
            Item('method'),
            Item('normal_mode', visible_when='method=="shaded"'),
            Item('vertexColour',
                 editor=EnumEditor(name='_datasource_keys'),
                 label='Colour'),
            Group([
                Item('clim',
                     editor=HistLimitsEditor(data=self._get_cdata),
                     show_label=False),
            ],
                  visible_when='vertexColour != "constant"'),
            Group([
                Item('cmap', label='LUT'),
                Item('alpha',
                     visible_when='method in ["flat", "tessel", "shaded"]')
            ])
        ], )

    def default_traits_view(self):
        return self.default_view
Exemplo n.º 18
0
class CalculateFRCFromLocs(CalculateFRCBase):
    """
    Generates a pair of images from localization data and calculates the fourier shell/ring correlation (FSC / FRC).
    
    Inputs
    ------
    inputName : TabularBase
        Localization data.
        
    Outputs
    -------
    outputName : TabularBase
        Localization data labeled with how it was divided (FRC_group).
    output_images : ImageStack
        Pair of 2D or 3D histogram rendered images.
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.
    
    Parameters
    ----------
    split_method : string
        Different methods of dividing data into halves.
    pixel_size_in_nm : float
        Pixel size used for rendering the images.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing  : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    
    """
    inputName = Input('Localizations')
    split_method = Enum(['halves_random', 'halves_time', 'halves_100_time_chunk', 'halves_10_time_chunk', 'fixed_time', 'fixed_10_time_chunk'])    
    pixel_size_in_nm = Float(5)
#    pre_filter = Enum(['Tukey_1/8', None])
#    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
#    plot_graphs = Bool()
#    multiprocessing =  Bool(True)
    flatten_z = Bool(True)
    
    outputName = Output('FRC_ds')
    output_images = Output('FRC_images')
#    output_fft_images = Output('FRC_fft_images')
#    output_frc_dict = Output('FRC_dict')
#    output_frc_plot = Output('FRC_plot')
    
    def execute(self, namespace):
        """
        Split the localizations in two, render both halves and compute the FRC/FSC.

        Writes the labeled localizations (``outputName``), the rendered image
        pair (``output_images``), the FRC results (``output_frc_dict``) and the
        raw FRC data (``output_frc_raw``) into ``namespace``, then calls
        ``save_to_file``.
        """
        self._namespace = namespace
        import multiprocessing

        if self.multiprocessing:
            # Use up to 2 workers (one per half) and leave a core free, but never
            # fewer than 1. np.clip(2, 1, cpu_count()-1) returned 0 on a
            # single-core machine, which makes Pool() raise ValueError.
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            pipeline = namespace[self.inputName]
            mapped_pipeline = tabular.mappingFilter(pipeline)
            self._pixel_size_in_nm = self.pixel_size_in_nm * np.ones(3, dtype=float)

            image_pair = self.generate_image_pair(mapped_pipeline)

            image_pair = self.preprocess_images(image_pair)

            # Should use DensityMapping recipe eventually when it is ready.
            mdh = MetaDataHandler.NestedClassMDHandler()
            mdh['Rendering.Method'] = "np.histogramdd"
            if 'imageID' in pipeline.mdh.getEntryNames():
                mdh['Rendering.SourceImageID'] = pipeline.mdh['imageID']
            try:
                mdh['Rendering.SourceFilename'] = pipeline.resultsSource.h5f.filename
            except Exception:
                # best effort - the source filename is optional metadata
                pass
            mdh.Source = MetaDataHandler.NestedClassMDHandler(pipeline.mdh)
            mdh['Rendering.NEventsRendered'] = [image_pair[0].sum(), image_pair[1].sum()]
            mdh['voxelsize.units'] = 'um'
            mdh['voxelsize.x'] = self.pixel_size_in_nm * 1E-3
            mdh['voxelsize.y'] = self.pixel_size_in_nm * 1E-3
            if not self.flatten_z:
                # z is binned with the same pixel size as x and y in generate_image_pair
                mdh['voxelsize.z'] = self.pixel_size_in_nm * 1E-3

            ims = ImageStack(data=np.stack(image_pair, axis=-1), mdh=mdh)
            namespace[self.output_images] = ims

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, pipeline.mdh)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # always release the worker processes, even if a step above raised
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)
        
    def generate_image_pair(self, mapped_pipeline):
        """
        Split localizations into two groups and histogram each group into an image.

        The group assignment is stored on ``mapped_pipeline`` as the
        ``FRC_group`` column and published as ``outputName``.

        Parameters
        ----------
        mapped_pipeline : tabular.mappingFilter
            Localizations with at least ``t``, ``x``, ``y`` (and ``z`` unless
            ``flatten_z``) columns.

        Returns
        -------
        image_a, image_b : ndarray
            Histogram images of the localizations with ``FRC_group`` 0 and 1.
        """
        t = mapped_pipeline['t']

        # Split localizations into 2 sets
        mask = np.zeros(t.shape, dtype=bool)
        if self.split_method == 'halves_time':
            sort_arg = np.argsort(t)
            mask[sort_arg[:len(sort_arg)//2]] = 1
        elif self.split_method == 'halves_random':
            mask[:len(mask)//2] = 1
            np.random.shuffle(mask)
        elif self.split_method == 'halves_100_time_chunk':
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] / 100.0
            for i in range(50):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'halves_10_time_chunk':
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] * 0.1
            for i in range(5):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'fixed_time':
            # np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed
            time_cutoff = (np.ptp(t) + 1) // 2 + t.min()
            mask[t < time_cutoff] = 1
        elif self.split_method == 'fixed_10_time_chunk':
            time_cutoffs = np.linspace(t.min(), t.max(), 11, dtype=float)
            for i in range((time_cutoffs.shape[0]-1)//2):
                # Inclusive lower bound so that localizations at exactly t.min()
                # belong to the first chunk rather than always landing in group 0.
                mask[(t >= time_cutoffs[i*2]) & (t < time_cutoffs[i*2+1])] = 1
        
        if self.split_method.startswith('halves_'):
            assert np.abs(mask.sum() - (~mask).sum()) <= 1, "datasets uneven, {} vs {}".format(mask.sum(), (~mask).sum())
        else:
            print("variable counts between images, {} vs {}".format(mask.sum(), (~mask).sum()))
        
        mapped_pipeline.addColumn("FRC_group", mask)
        self._namespace[self.outputName] = mapped_pipeline
        
        dims = ['x', 'y', 'z']
        if self.flatten_z:
            dims.remove('z')
            
        # Simple histogram binning, one pixel_size_in_nm-wide bin per pixel in every dimension
        bins = []
        data = []
        for dim in dims:
            values = mapped_pipeline[dim]
            bins.append(np.arange(np.floor(values.min()).astype(int),
                                  np.ceil(values.max()).astype(int) + self.pixel_size_in_nm,
                                  self.pixel_size_in_nm))
            data.append(values)
        
        data = np.stack(data, axis=1)
        
        if self.multiprocessing:
            results = list()
            results.append(self._pool.apply_async(np.histogramdd, (data[mask==0],), kwds= {"bins":bins}))
            results.append(self._pool.apply_async(np.histogramdd, (data[mask==1],), kwds= {"bins":bins}))
            
            image_a = results[0].get()[0]
            image_b = results[1].get()[0]
        else:
            image_a = np.histogramdd(data[mask==0], bins=bins)[0]
            image_b = np.histogramdd(data[mask==1], bins=bins)[0]
            
        return image_a, image_b
Exemplo n.º 19
0
class EnsembleFitROIs(ModuleBase):  # Note that this should probably be moved somewhere else
    """
    Fit a set of ROI profiles with a region fitter.

    The input ROIs are loaded into a ``RegionHandler``. The fitter class named
    by ``fit_type`` is looked up in ``region_fitters.fitters``. Depending on
    ``hold_ensemble_parameter_constant``, either ``fit_profiles`` or
    ``ensemble_fit`` is run with ``ensemble_parameter_guess``. The fitter's
    results are output as a ``tabular.RecArraySource``, with the fit settings
    recorded in its metadata.
    """
    inputName = Input('ROIs')

    fit_type = CStr('LorentzianConvolvedSolidSphere_ensemblePSF')
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        source = namespace[self.inputName]

        # populate a RegionHandler from the input tables
        regions = RegionHandler()
        regions._load_from_list(source)

        fitter = region_fitters.fitters[self.fit_type](regions)

        guess = self.ensemble_parameter_guess
        if self.hold_ensemble_parameter_constant:
            fitter.fit_profiles(guess)
        else:
            fitter.ensemble_fit(guess)

        results = tabular.RecArraySource(fitter.results)

        # carry over the input's metadata (if it has any) and record the fit settings
        results.mdh = MetaDataHandler.NestedClassMDHandler(getattr(source, 'mdh', None))
        results.mdh['EnsembleFitROIs.FitType'] = self.fit_type
        results.mdh['EnsembleFitROIs.EnsembleParameterGuess'] = guess
        results.mdh['EnsembleFitROIs.HoldEnsembleParamConstant'] = self.hold_ensemble_parameter_constant

        namespace[self.outputName] = results

    @property
    def _fitter_choices(self):
        choices = region_fitters.ensemble_fitters.keys()  #FIXME???
        return list(choices)

    def _parameter_items(self):
        """Build the fit-parameter Items shared by every view of this module."""
        from traitsui.api import Item
        from PYME.ui.custom_traits_editors import CBEditor

        return [
            Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
            Item('_'),
            Item('ensemble_parameter_guess'),
            Item('_'),
            Item('hold_ensemble_parameter_constant'),
        ]

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        items = [Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                 Item('_')]
        items.extend(self._parameter_items())
        items.extend([Item('_'), Item('outputName')])
        return View(*items, buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View

        return View(*self._parameter_items())

    @property
    def dsview_view(self):
        from traitsui.api import View

        return View(*self._parameter_items(), buttons=['OK'])
Exemplo n.º 20
0
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data.

    Points are read from the pipeline datasource named by ``dsname``. The
    display engine is built from ``ENGINES[method]``. Point positions, normals
    and colours are cached by ``update_data`` and exposed through
    ``get_vertices`` / ``get_normals`` / ``get_colors``.
    """
    
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    display_normals = Bool(False)
    normal_scaling = Float(10.0)
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='points', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x' #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'
        
        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None
    
        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)
        
        # update datasource name and method
        self.dsname = dsname
        if self.method == method:
            # the trait will not change, so the handler won't fire - build the engine explicitly
            self._set_method()
        else:
            # the 'method' trait-change handler calls _set_method (building the
            # engine only once, rather than once here and once explicitly)
            self.method = method
        
        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)


    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """(Re)create the display engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the values of the colour variable, or ``[0, 1]`` if unavailable."""
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            # TypeError: datasource is None (update() allows for that case)
            cdata = np.array([0, 1])
    
        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the newly selected colour variable."""
        cdata = self._get_cdata()
        # small offset keeps the colour range non-degenerate for constant data
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))+1e-9]

    def update(self, *args, **kwargs):
        """Refresh the GUI choice lists and, if possible, the cached point data, then notify listeners."""
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            self._datasource_keys = sorted(self.datasource.keys())
            
        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)


    @property
    def bbox(self):
        return self._bbox
    
    
    def update_from_datasource(self, ds):
        """
        Extract coordinates, colour values and (optional) normals from ``ds`` and pass them to ``update_data``.

        ``z`` falls back to zeros when there is no z column. Colour values fall
        back to zeros when no colour variable is selected.
        """
        x, y = ds[self.x_key], ds[self.y_key]
        
        if not self.z_key is None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0*x
        else:
            z = 0 * x
        
        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0*x
            
        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha, xn=xn, yn=yn, zn=zn)
        else:
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha)
    
    
    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0, xn=None, yn=None, zn=None):
        """
        Build and cache the (N, 3) vertex/normal arrays and the (N, 4) RGBA colour array.

        Parameters
        ----------
        x, y, z : array-like or None
            Point coordinates. If any is None or empty, no vertices are cached
            and ``bbox`` becomes None.
        colors : array-like or None
            Per-point values mapped through ``cmap`` after scaling by ``clim``.
        cmap : callable or None
            Colour map returning RGBA rows for normalised values.
        clim : sequence of 2 floats or None
            Lower/upper colour limits.
        alpha : float
            Alpha written into every colour.
        xn, yn, zn : array-like or None
            Optional normals. A constant placeholder normal is used when absent.
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0
        
        if x is not None and y is not None and z is not None and len(x) > 0:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)
            
            if not xn is None:
                normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)
            
            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None
        
        # cmap (which is called below) was previously never checked - the
        # condition tested clim twice
        if cmap is not None and colors is not None and clim is not None:
            span = clim[1] - clim[0]
            if span == 0:
                # user-set equal limits would otherwise divide by zero -> NaN colours
                span = 1.0
            cs_ = (colors - clim[0]) / span
            cs = cmap(cs_)
            cs[:, 3] = alpha
            
            cs = cs.ravel().reshape(len(colors), 4)
        else:
            if not vertices is None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self, vertices=None, normals=None, colors=None, color_map=None, color_limit=None, alpha=None):
        """Store each supplied value. Arguments left as None keep the current value."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha
    
    def get_vertices(self):
        return self._vertices
    
    def get_normals(self):
        return self._normals
    
    def get_colors(self):
        return self._colors
    
    def get_color_map(self):
        return self._color_map
    
    @property
    def colour_map(self):
        return self._color_map
    
    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor
    
        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour', visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update), show_label=False), ], visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group(Item('cmap', label='LUT'),
                           Item('alpha', visible_when="method in ['pointsprites', 'transparent_points']", editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                           Item('point_size', label=u'Point\u00A0size', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)))])

    def default_traits_view(self):
        return self.default_view
Exemplo n.º 21
0
class LUTOverlayLayer(OverlayLayer):
    """
    This OverlayLayer draws one colour bar per visible layer, showing that layer's colour map.

    When ``show_bounds`` is True, the layer's ``clim`` limits are drawn as text
    labels at either end of its bar.
    """

    show_bounds = Bool(False)

    def __init__(self, offset=None, **kwargs):
        """
        Parameters
        ----------
        offset : sequence of 2 numbers, optional
            Offset of the bars from the canvas edge (defaults to ``[10, 10]``).
            Currently only ``offset[0]`` is used when positioning the bars.
        """
        if not offset:
            offset = [10, 10]

        OverlayLayer.__init__(self, offset, **kwargs)

        self.set_offset(offset)

        self._lut_width_px = 10.0
        self._border_colour = [.5, .5, 0]
        self.set_shader_program(DefaultShaderProgram)

        # per-layer [lower, upper] text labels, created lazily by _get_label
        self._labels = {}

    def _get_label(self, layer):
        """Return (creating on first use) the pair of Text labels for ``layer``."""
        from . import text
        try:
            return self._labels[layer]
        except KeyError:
            self._labels[layer] = [text.Text(), text.Text()]

            return self._labels[layer]

    def render(self, gl_canvas):
        """
        Draw the colour bars (and optionally their limit labels) onto ``gl_canvas``.

        Layers are included if they are visible and do not set ``show_lut`` to
        False. Layers without a usable ``colour_map`` are skipped.
        """
        if not self.visible:
            return

        self._clear_shader_clipping()
        labels = []

        with self.shader_program:
            glDisable(GL_DEPTH_TEST)
            glDisable(GL_LIGHTING)
            glDisable(GL_BLEND)

            view_size_x, view_size_y = gl_canvas.Size

            # upper right y
            lb_ur_y = .1 * view_size_y
            # lower right y
            lb_lr_y = 0.9 * view_size_y

            lb_len = lb_lr_y - lb_ur_y

            lb_width = self._lut_width_px

            # Collect (layer, colour map) pairs up front. A layer whose colour map
            # is missing or not yet available is skipped rather than crashing the
            # whole render. For example, TriangleRenderLayer.colour_map raises
            # AttributeError until its first data update, and getattr maps that
            # to None.
            lut_layers = []
            for l in gl_canvas.layers:
                if not (getattr(l, 'visible', True) and getattr(l, 'show_lut', True)):
                    continue
                cmap = getattr(l, 'colour_map', None)
                if cmap is not None:
                    lut_layers.append((l, cmap))

            for j, (l, cmap) in enumerate(lut_layers):
                # bars are laid out right-to-left, 1.5 bar widths apart
                # upper right x
                lb_ur_x = view_size_x - self.get_offset()[0] - j * 1.5 * lb_width

                # upper left x
                lb_ul_x = lb_ur_x - lb_width

                glBegin(GL_QUAD_STRIP)

                # 101 evenly spaced samples of the colour map over [0, 1]
                for i in numpy.linspace(0, 1, 101):
                    glColor3fv(cmap(i)[:3])
                    glVertex2f(lb_ul_x, lb_ur_y + (1. - i) * lb_len)
                    glVertex2f(lb_ur_x, lb_ur_y + (1. - i) * lb_len)

                glEnd()

                glBegin(GL_LINE_LOOP)
                glColor3fv(self._border_colour)
                glVertex2f(lb_ul_x, lb_lr_y)
                glVertex2f(lb_ur_x, lb_lr_y)
                glVertex2f(lb_ur_x, lb_ur_y)
                glVertex2f(lb_ul_x, lb_ur_y)
                glEnd()

                if hasattr(l, 'clim') and self.show_bounds:
                    tl, tu = self._get_label(l)
                    cl, cu = l.clim
                    tu.text = '%.3G' % cu
                    tl.text = '%.3G' % cl

                    xc = lb_ur_x - 0.5 * lb_width

                    tu.pos = (xc - tu._w / 2, lb_ur_y - tu._h)
                    tl.pos = (xc - tl._w / 2, lb_lr_y)

                    labels.extend([tl, tu])

        # text is rendered after leaving the colour-bar shader program
        for l in labels:
            l.render(gl_canvas)
Exemplo n.º 22
0
class FiducialTrack(ModuleBase):
    """
    Extract average fiducial track from input pipeline

    Parameters
    ----------

        radiusMultiplier: multiplied with error_x to obtain the search radius used for clustering
        timeWindow: window along the time dimension used for clustering
        filterScale: size of the filter kernel used to smooth the resulting average fiducial track
        filterMethod: enumerated choice of smoothing filter (keys of trackFiducials.FILTER_FUNCS)
        clumpMinSize: minimum clump size, passed through to extractTrajectoriesClump
        singleFiducial: if True, all data is treated as coming from one fiducial and track
                        fragments are not aligned to each other

    Notes
    -----

    Output is a new pipeline with a fiducial_<dim> column for every dimension returned by
    AverageTrack (e.g. fiducial_x, fiducial_y) plus an isFiducial column.

    """
    # imported in the class body because filterMethod needs the filter names at class-definition time
    import PYMEcs.Analysis.trackFiducials as tfs
    inputName = Input('filtered')

    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    filterMethod = Enum(tfs.FILTER_FUNCS.keys())
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        import PYMEcs.Analysis.trackFiducials as tfs

        source = namespace[self.inputName]
        output = tabular.mappingFilter(source)

        # Data from a single fiducial needs no alignment; skipping it avoids offsets
        # between fragments of incomplete fiducial tracks.
        align = not self.singleFiducial

        t, x, y, z, is_fiducial = tfs.extractTrajectoriesClump(
            source,
            clumpRadiusVar='error_x',
            clumpRadiusMultiplier=self.radiusMultiplier,
            timeWindow=self.timeWindow,
            clumpMinSize=self.clumpMinSize,
            align=align)
        averaged = tfs.AverageTrack(source,
                                    (t, x, y, z),
                                    filter=self.filterMethod,
                                    filterScale=self.filterScale,
                                    align=align)

        # one output column per dimension of the averaged track
        for dim_name in averaged.keys():
            output.addColumn('fiducial_%s' % dim_name, averaged[dim_name])
        output.addColumn('isFiducial', is_fiducial)

        # carry metadata over when the input has any
        try:
            output.mdh = source.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = output

    @property
    def hide_in_overview(self):
        return ['columns']
Exemplo n.º 23
0
class EnsembleFitProfiles(ModuleBase):
    """Fit line profiles with one of ``profile_fitters.ensemble_fitters``.

    A LineProfileHandler is rebuilt from the input, the fitter named by ``fit_type`` is
    constructed around it, and either ``fit_profiles`` (when
    ``hold_ensemble_parameter_constant`` is set) or ``ensemble_fit`` is called with
    ``ensemble_parameter_guess``. The fitter's ``results`` are published as a table whose
    metadata records the fit settings.
    """
    inputName = Input('line_profiles')

    fit_type = CStr(list(profile_fitters.ensemble_fitters.keys())[0])
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        source = namespace[self.inputName]

        # rebuild a LineProfileHandler from the input tables
        profile_handler = LineProfileHandler()
        profile_handler._load_profiles_from_list(source)

        self.fitter = profile_fitters.ensemble_fitters[self.fit_type](profile_handler)

        if self.hold_ensemble_parameter_constant:
            run_fit = self.fitter.fit_profiles
        else:
            run_fit = self.fitter.ensemble_fit
        run_fit(self.ensemble_parameter_guess)

        results = tabular.RecArraySource(self.fitter.results)

        # carry over input metadata (if any), then record the fit settings
        results.mdh = MetaDataHandler.NestedClassMDHandler(
            getattr(source, 'mdh', None))
        fit_settings = (
            ('EnsembleFitProfiles.FitType', self.fit_type),
            ('EnsembleFitProfiles.EnsembleParameterGuess', self.ensemble_parameter_guess),
            ('EnsembleFitProfiles.HoldEnsembleParamConstant', self.hold_ensemble_parameter_constant),
        )
        for mdh_key, mdh_value in fit_settings:
            results.mdh[mdh_key] = mdh_value

        namespace[self.outputName] = results

    @property
    def _fitter_choices(self):
        return list(profile_fitters.ensemble_fitters.keys())

    def _fit_setting_items(self):
        """Return the fit-setting Items shared by every view, separated by '_' rules."""
        from traitsui.api import Item
        from PYME.ui.custom_traits_editors import CBEditor

        return [Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                Item('_'),
                Item('ensemble_parameter_guess'),
                Item('_'),
                Item('hold_ensemble_parameter_constant')]

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        items = [Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                 Item('_')]
        items.extend(self._fit_setting_items())
        items.extend([Item('_'), Item('outputName')])
        return View(*items, buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View

        return View(*self._fit_setting_items())

    @property
    def dsview_view(self):
        from traitsui.api import View

        return View(*self._fit_setting_items(), buttons=['OK'])
class AlignPSF(ModuleBase):
    """
        Align PSF stacks by redundant cross correlation.

        Every pair of PSFs in the 4D input (indexed [x, y, z, psf]) is cross-correlated
        and the correlation peak located with a Gaussian or RBF fit. The pairwise offsets
        are solved by least squares for one step drift per PSF. Pairs whose residual
        exceeds the tolerance are dropped (only while the system stays full rank). The
        accumulated drifts, corrected for the centre of mass of the mean PSF, are applied
        with a Fourier shift.

        Outputs: the aligned stack (``output_images``), plus the thresholded
        cross-correlations and their fitted peak models (``output_cross_corr_images``,
        ``output_cross_corr_images_fitted``).
    """

    inputName = Input('psf_cropped')
    normalize_z = Bool(True)
    tukey = Float(0.50)
    rcc_tolerance = Float(5.0)
    z_crop_half_roi = Int(15)
    peak_detect = Enum(['Gaussian', 'RBF'])
    debug = Bool(False)
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        """Align the PSFs in ``namespace[self.inputName]`` and store the aligned stack.

        When ``debug`` is set, the shifts are applied to the cropped, normalised and
        apodised stack instead of the raw input.
        """
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C'
        psf_stack = ims.data[:, :, :, :]

        # Crop up to 2*z_crop_half_roi+1 slices around the middle slice. The start is
        # clamped at 0 so a large half-ROI does not become a negative (wrapping) index.
        z_mid = psf_stack.shape[2] // 2
        z_slice = slice(max(z_mid - self.z_crop_half_roi, 0),
                        z_mid + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:, :, z_slice, :])

        if self.tukey > 0:
            # separable 3D Tukey window applied to every PSF to taper the edges before correlating
            try:
                tukey_window = signal.windows.tukey
            except AttributeError:
                # older scipy without the scipy.signal.windows namespace
                tukey_window = signal.tukey
            masks = [tukey_window(dim_len, self.tukey) for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:, :, :, None]

        # rcc_tolerance is converted with voxelsize.x into the index units of the fitted shifts
        # NOTE(review): the units implied by the 1E3 factor are not documented here - confirm
        drifts = self.calculate_shifts(cleaned_psf_stack, self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        namespace[self.output_images] = ImageStack(
            self.shift_images(cleaned_psf_stack if self.debug else psf_stack, drifts), mdh=ims.mdh)

    def normalize_images(self, psf_stack):
        """Return a background-subtracted, normalised copy of ``psf_stack``.

        Negative values are clipped and the minimum of each PSF is subtracted. Then each
        z plane (``normalize_z``) or each whole PSF is scaled so its maximum ends at 1
        after a -0.05 offset. Values below ~5% of that maximum are clipped to zero.
        """
        # in case it is already bg subtracted; use floating point so the in-place
        # divisions below also work for integer input
        cleaned_psf_stack = np.clip(psf_stack, 0, None)
        if not np.issubdtype(cleaned_psf_stack.dtype, np.floating):
            cleaned_psf_stack = cleaned_psf_stack.astype(float)

        # subtract bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0, 1, 2), keepdims=True)

        if self.normalize_z:
            # normalize intensity per plane
            peak = cleaned_psf_stack.max(axis=(0, 1), keepdims=True)
        else:
            # normalize intensity per psf stack
            peak = cleaned_psf_stack.max(axis=(0, 1, 2), keepdims=True)
        # an all-zero plane/stack stays zero instead of becoming NaN
        peak[peak == 0] = 1
        cleaned_psf_stack /= peak / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """Estimate the shift of every PSF in ``psf_stack`` (indexed [x, y, z, psf]).

        Parameters
        ----------
        psf_stack : ndarray
            Normalised 4D stack, one PSF per index of the last axis.
        drift_tolerance : float
            Residual norm (same units as the fitted shifts) above which a pairwise
            measurement is dropped, as long as the system stays full rank.

        Returns
        -------
        ndarray, shape (n_psf, 3)
            Per-PSF shift along the three spatial axes, as consumed by ``shift_images``.

        Side effects: stores the thresholded and fitted correlations in the namespace
        saved by ``execute`` and draws a diagnostic matplotlib figure.
        """
        n_steps = psf_stack.shape[3]
        # one row per unordered pair (i, j), i < j; integer division keeps these valid array sizes
        coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        corr_shape = psf_stack.shape[:3] + (coefs_size,)
        output_cross_corr_images = np.zeros(corr_shape, dtype=float)
        output_cross_corr_images_fitted = np.zeros(corr_shape, dtype=float)

        counter = 0
        for i in np.arange(0, n_steps - 1):
            for j in np.arange(i + 1, n_steps):
                # the shift between PSF i and PSF j is the sum of step drifts i..j-1
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))
                correlate_result = signal.correlate(psf_stack[:, :, :, i], psf_stack[:, :, :, j], mode="same")
                correlate_result -= correlate_result.min()
                correlate_result /= correlate_result.max()

                # only the upper half of the normalised correlation is used for peak fitting
                threshold = 0.50
                correlate_result[correlate_result < threshold] = np.nan

                labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
                # Protects against > 1 peak in the cross correlation results. This shouldn't
                # happen, but at least avoid fitting a single peak to multi-modal data: keep
                # only the connected region holding the HIGHEST maximum (labels are 1-based).
                if labeled_counts > 1:
                    region_maxima = ndimage.maximum(correlate_result, labeled_image,
                                                    np.arange(labeled_counts) + 1)
                    brightest_label = np.argmax(region_maxima) + 1
                    correlate_result[labeled_image != brightest_label] = np.nan

                output_cross_corr_images[:, :, :, counter] = np.nan_to_num(correlate_result)

                # zero-centred coordinate axes, so a fitted peak position is an offset from the centre
                dims = [np.arange(dim_len) - np.arange(dim_len).mean() for dim_len in correlate_result.shape]

                if self.peak_detect == "Gaussian":
                    res = optimize.least_squares(guassian_nd_error,
                                                 [1, 0, 0, 5., 0, 5., 0, 30.],
                                                 args=(dims, correlate_result))
                    output_cross_corr_images_fitted[:, :, :, counter] = gaussian_nd(res.x, dims)
                    # parameters 2, 4 and 6 are presumably the peak centre along x, y, z
                    shifts[counter, 0] = res.x[2]
                    shifts[counter, 1] = res.x[4]
                    shifts[counter, 2] = res.x[6]
                elif self.peak_detect == "RBF":
                    rbf_interpolator = build_rbf(dims, correlate_result)
                    # NOTE(review): dims are zero-centred but this initial guess is shape*0.5 -
                    # confirm the coordinate convention of build_rbf / rbf_nd_error
                    res = optimize.minimize(rbf_nd_error,
                                            [correlate_result.shape[0] * 0.5,
                                             correlate_result.shape[1] * 0.5,
                                             correlate_result.shape[2] * 0.5],
                                            args=rbf_interpolator)
                    output_cross_corr_images_fitted[:, :, :, counter] = rbf_nd(rbf_interpolator, dims)
                    shifts[counter, :] = res.x
                else:
                    raise Exception("peak founding method not recognised")

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        # least-squares estimate of the step drifts from all pairwise shifts
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals = np.matmul(coefs, drifts) - shifts
        residuals_dist = np.linalg.norm(residuals, axis=1)

        shift_max = drift_tolerance
        # Sort and mask residual errors
        residuals_arg = np.argsort(-residuals_dist)
        residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

        # Remove coefs rows
        # Descending from largest residuals to small
        # Only if matrix remains full rank
        coefs_temp = np.empty_like(coefs)
        counter = 0
        for index in residuals_arg:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                counter += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(counter))
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        # prepend zero drift for the first PSF, then accumulate step drifts into absolute shifts
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)
        np.cumsum(drifts, axis=0, out=drifts)

        # re-centre on the centre of mass of the thresholded (>50% of max) mean PSF
        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0, 1, 2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        center_offset = ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape) * 0.5

        drifts = drifts - center_offset

        # diagnostic plot of the drifts; failures are reported but do not abort alignment
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6, 3))
            axes[0].scatter(drifts[:, 0], drifts[:, 1], s=50)
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            axes[1].scatter(drifts[:, 0], drifts[:, 2], s=50)
            axes[1].set_xlabel('x')
            axes[1].set_ylabel('z')

            for ax in axes:
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')

            fig.tight_layout()
        except Exception as e:
            print(e)

        return drifts

    def shift_images(self, psf_stack, shifts):
        """Return a copy of ``psf_stack`` with PSF ``i`` shifted by ``shifts[i]`` (x, y, z samples).

        The shift is a Fourier-space phase ramp, so it is periodic (content wraps at the
        array edges); the magnitude of the inverse transform is returned.
        """
        kx = np.fft.fftfreq(psf_stack.shape[0])
        ky = np.fft.fftfreq(psf_stack.shape[1])
        kz = np.fft.fftfreq(psf_stack.shape[2])
        kx, ky, kz = np.meshgrid(kx, ky, kz, indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in np.arange(psf_stack.shape[3]):
            ft_image = np.fft.fftn(psf_stack[:, :, :, i])
            shift = shifts[i]
            phase_ramp = np.exp(-2j * np.pi * (kx * shift[0] + ky * shift[1] + kz * shift[2]))
            shifted_images[:, :, :, i] = np.abs(np.fft.ifftn(ft_image * phase_ramp))

        return shifted_images
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-corelation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """Cross-correlate pairs of Fourier-transformed frames and build the RCC system.

        Which pairs are correlated depends on ``method``:
        - DCC: frame 0 against every later frame.
        - MCC/RCC with ``corr_window > 0``: pairs at most ``corr_window`` apart.
        - Otherwise: all pairs.

        Parameters
        ----------
        ft_images : array
            Fourier-transformed frames stacked along axis 0 (one per time step).

        Returns
        -------
        shifts : ndarray, shape (n_pairs, 3)
            Measured shift for each correlated pair, with rows containing NaN removed.
        coefs : ndarray, shape (n_pairs, n_steps - 1)
            Coefficient matrix; row k is 1 over the consecutive drift steps spanned by pair k.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            # A window at least as wide as the stack correlates all pairs. Clamp it so the
            # pair count below is not under-estimated (it could even go to zero or negative).
            window = min(self.corr_window, n_steps - 1)
            coefs_size = n_steps * window - window * (window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [
                shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                (ft_images.shape[3] - 1) * 2
            ]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]),
                          key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            cc_file_args = (self.debug_cor_file, np.float64,
                            tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0],
                                dtype=cc_file_args[1],
                                mode="w+",
                                shape=cc_file_args[2])
            # materialised as a list: it is indexed by pair number below, which a
            # Python 3 zip iterator does not support
            cc_args = list(zip(range(shifts.shape[0]),
                               (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (
                        j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift,
                                                    None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print(
                            "{:.2f} s. Completed calculating {} of {} total shifts."
                            .format(time.time() - self._start_time,
                                    counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(
                range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                autocor_shift_cache,
                len(ft_1_cache) *
                ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                cc_args)
            for i, (j, res) in enumerate(
                    self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print(
                        "{:.2f} s. Completed calculating {} of {} total shifts."
                        .format(time.time() - self._start_time, i + 1,
                                coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(
            time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(
            coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop pairs whose shift could not be computed (NaN rows)
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".
                  format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (
            np.linalg.matrix_rank(coefs) == n_steps -
            1), "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs  # coefs.shape[1] is n_steps - 1

    def rcc(
        self,
        shift_max,
        t_shift,
        shifts,
        coefs,
    ):
        """
            Should probably rename function.
            Takes cross correlation results and calculates shifts.

            Solves ``coefs @ drifts = shifts`` by least squares. For method "RCC", rows
            whose residual norm exceeds ``shift_max`` are then zeroed, largest first,
            while the matrix stays full rank, and the system is solved again.

            Returns ``(t_shift, drifts)`` where ``drifts`` holds per-step drifts with a
            zero row prepended for the first time point (not yet accumulated).
        """

        print("{:.2f} s. About to start solving shifts array.".format(
            time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() -
                                                            self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[
                residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print(
                        "Could not remove all residuals over shift_max threshold."
                    )
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".
                  format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]],
                        'constant',
                        constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # dervied versions of RCC need to override this method
        # 'execute' of this RCC base class is not throughly tested as its use is probably quite limited.
        # NOTE(review): as written, calc_corr_drift_from_ft_images receives the cache
        # filename and returns (shifts, coefs), while rcc also expects t_shift - this base
        # implementation looks unusable without an override; confirm before relying on it.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # release the worker processes even when correlation or solving fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class AveragePSF(ModuleBase):
    """
        Input stacks of PSF and return the (normalized) average PSF.
        Additional filter based on max error/residual between image and averaged image.

        PSFs (indexed [x, y, z, psf]) whose maximum absolute deviation from the mean
        normalised PSF is not below ``residual_threshold`` are excluded. The remaining
        PSFs are averaged, optionally Gaussian-filtered (``gaussian_filter`` sigmas per
        axis, all zero = off), and rescaled to [0, 1]. The per-pixel variance of the kept
        normalised PSFs is published as ``output_var_image``.
    """

    inputName = Input('psf_aligned')
    normalize_intensity = Bool(False)
    output_var_image = Output('psf_var')
    gaussian_filter = List(Float, [0, 0, 0], 3, 3)
    residual_threshold = Float(0.1)
    output_images = Output('psf_combined')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_raw = ims.data[:, :, :, :]
        # the in-place divisions below need floating point data
        if not np.issubdtype(psf_raw.dtype, np.floating):
            psf_raw = psf_raw.astype(float)

        # always normalize first, since needed for statistics
        psf_raw_norm = psf_raw.copy()
        psf_raw_norm /= psf_raw_norm.max(axis=(0, 1, 2), keepdims=True)
        psf_raw_norm /= psf_raw_norm.sum(axis=(0, 1), keepdims=True)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        # largest absolute deviation of each PSF from the mean PSF
        residual_max = np.abs(psf_raw_norm - psf_raw_norm.mean(axis=3, keepdims=True)).max(axis=(0, 1, 2))
        print(residual_max)
        mask = residual_max < self.residual_threshold
        print("images ignore: {}".format(np.argwhere(~mask)[:, 0]))
        print(mask)
        if not mask.any():
            raise ValueError(
                "No PSF has a maximum residual below residual_threshold ({}); "
                "nothing to average.".format(self.residual_threshold))

        psf_raw_norm = psf_raw_norm[:, :, :, mask]
        print(psf_raw_norm.shape)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        psf_var = psf_raw_norm.var(axis=3)
        namespace[self.output_var_image] = ImageStack(psf_var, mdh=ims.mdh)

        # if requested not to normalize, revert back to original data
        # (boolean indexing already returns a copy)
        if not self.normalize_intensity:
            psf_raw_norm = psf_raw[:, :, :, mask]

        psf_combined = psf_raw_norm.mean(axis=3)
        psf_combined -= psf_combined.min()
        psf_combined /= psf_combined.max()

        if np.any(np.asarray(self.gaussian_filter) != 0):
            psf_processed = ndimage.gaussian_filter(psf_combined, self.gaussian_filter)
        else:
            psf_processed = psf_combined

        psf_processed -= psf_processed.min()
        psf_processed /= psf_processed.max()

        # metadata is best-effort: on failure the output is created without it
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["PSFExtraction.GaussianFilter"] = self.gaussian_filter
            new_mdh["PSFExtraction.NormalizeIntensity"] = self.normalize_intensity
        except Exception as e:
            print(e)
        namespace[self.output_images] = ImageStack(psf_processed, mdh=new_mdh)

        # diagnostic plots: top row = central line profiles of each kept PSF,
        # bottom row = averaged (red) vs processed (black dashed) profiles
        fig, axes = pyplot.subplots(2, 3, figsize=(9, 6))
        axes[0, 0].set_title('X')
        axes[0, 0].plot(psf_raw_norm[:, psf_raw_norm.shape[1] // 2, psf_raw_norm.shape[2] // 2, :])
        axes[1, 0].plot(psf_combined[:, psf_combined.shape[1] // 2, psf_combined.shape[2] // 2], lw=1, color='red')
        axes[1, 0].plot(psf_processed[:, psf_processed.shape[1] // 2, psf_processed.shape[2] // 2], lw=1, ls='--', color='black')
        axes[0, 1].set_title('Y')
        axes[0, 1].plot(psf_raw_norm[psf_raw_norm.shape[0] // 2, :, psf_raw_norm.shape[2] // 2, :])
        axes[1, 1].plot(psf_combined[psf_combined.shape[0] // 2, :, psf_combined.shape[2] // 2], lw=1, color='red')
        axes[1, 1].plot(psf_processed[psf_processed.shape[0] // 2, :, psf_processed.shape[2] // 2], lw=1, ls='--', color='black')
        axes[0, 2].set_title('Z')
        axes[0, 2].plot(psf_raw_norm[psf_raw_norm.shape[0] // 2, psf_raw_norm.shape[1] // 2, :, :])
        axes[1, 2].plot(psf_combined[psf_combined.shape[0] // 2, psf_combined.shape[1] // 2, :], lw=1, color='red')
        axes[1, 2].plot(psf_processed[psf_processed.shape[0] // 2, psf_processed.shape[1] // 2, :], lw=1, ls='--', color='black')

        fig.tight_layout()

        # histogram of per-PSF residuals with the rejection threshold
        fig, ax = pyplot.subplots(1, 1, figsize=(4, 3))
        ax.hist(residual_max, bins=20)
        ax.axvline(self.residual_threshold, color='red', ls='--')
Exemplo n.º 27
0
class LayerWrapper(HasTraits):
    """
    Traits wrapper that binds a pipeline data source to a rendering engine.

    Holds the user-facing display settings (colour map, colour limits,
    opacity, visibility, rendering method and data source name). When they
    change, it pushes them to the engine constructed from ``ENGINES``.
    ``on_update`` is sent after each engine update and whenever ``visible``
    changes.
    """
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self,
                 pipeline,
                 method='points',
                 ds_name='',
                 cmap='gist_rainbow',
                 clim=None,
                 alpha=1.0,
                 visible=True,
                 method_args=None):
        """
        Parameters
        ----------
        pipeline
            Pipeline object. Its ``layer_datasources`` mapping supplies the
            named data sources, and its ``onRebuild`` signal triggers updates.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        ds_name : str
            Name of the data source in ``pipeline.layer_datasources``.
            ``''`` selects the pipeline itself. ``'<name>.<channel>'``
            selects a channel of a source that supports ``get_channel_ds``.
        cmap : str
            Name of a colour map attribute of ``cm``.
        clim : list of float, optional
            Colour limits ``[low, high]``. Defaults to ``[0, 1]``.
        alpha : float
            Layer opacity.
        visible : bool
            Whether the layer is rendered.
        method_args : dict, optional
            Initial parameters applied to the engine when it is created.
        """
        self._pipeline = pipeline
        self.engine = None

        self.cmap = cmap
        # None default avoids a mutable default list shared between calls
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha

        self.visible = visible

        self.on_update = dispatch.Signal()

        # Visibility only needs listeners to redraw. Display settings need
        # the engine refreshed. A new method needs a new engine.
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        self._eng_params = dict(method_args) if method_args is not None else {}
        # NOTE(review): traits only notifies on an actual value change, so if
        # ``method`` equals the Enum default, ``_set_method`` is not called
        # here and ``engine`` stays None - confirm this is intended.
        self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        """Mapping of data source names to data sources, taken from the pipeline."""
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        """Bounding box reported by the current engine."""
        return self.engine.bbox

    @property
    def colour_map(self):
        """Colour map reported by the current engine."""
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """
        List the selectable data source names.

        ``ColourFilter`` sources also contribute one ``'<name>.<channel>'``
        entry per colour channel.
        """
        names = []
        for k, v in self._namespace.items():
            names.append(k)
            if isinstance(v, tabular.ColourFilter):
                for c in v.getColourChans():
                    names.append('.'.join([k, c]))

        return names

    @property
    def datasource(self):
        """
        Return the data source selected by ``dsname``.

        Returns None if no matching source exists.
        """
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds_name, channel = parts
            parent = self._namespace.get(ds_name, None)
            if parent is not None:
                return parent.get_channel_ds(channel)
            # Unknown parent: fall through to a plain lookup of the full name
            # instead of calling get_channel_ds on None.

        return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """Replace the engine with one for the current ``method``, carrying shared settings over."""
        if self.engine is not None:
            # keep point size and colour key when switching engines
            self._eng_params = self.engine.get('point_size', 'vertexColour')

        # NOTE(review): ``self._context`` is not assigned anywhere in this
        # class - confirm it is provided elsewhere before this runs.
        self.engine = ENGINES[self.method](self._context)
        self.engine.set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the full range of the current colour data."""
        cdata = np.asarray(self._get_cdata())
        if cdata.size == 0:
            # min()/max() of an empty array raise ValueError; keep current limits
            return
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Push the current settings to the engine and notify ``on_update`` listeners."""
        if not (self.engine is None or self.datasource is None):
            self.engine.update_from_datasource(self.datasource,
                                               getattr(cm, self.cmap),
                                               self.clim, self.alpha)
            self.on_update.send(self)

    def render(self, gl_canvas):
        """Render the engine on ``gl_canvas`` if the layer is visible."""
        if self.visible:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """
        Return the colour-data column named by the engine's ``vertexColour``.

        Falls back to ``[0, 1]`` when the column or the data source is
        missing.
        """
        ds = self.datasource
        if ds is not None:
            try:
                return ds[self.engine.vertexColour]
            except KeyError:
                pass

        return np.array([0, 1])

    @property
    def default_view(self):
        """traitsui View exposing the data source, method, engine, colour limits and display settings."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View(
            [
                Group([
                    Item('dsname',
                         label='Data',
                         editor=EnumEditor(values=self.data_source_names)),
                ]),
                Item('method'),
                Group([
                    Item('engine',
                         style='custom',
                         show_label=False,
                         editor=InstanceEditor(
                             view=self.engine.view(self.datasource.keys()))),
                ]),
                Group([
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata),
                         show_label=False),
                ]),
                Group([
                    Item('cmap', label='LUT'),
                    Item('alpha'),
                    Item('visible')
                ],
                      orientation='horizontal',
                      layout='flow')
            ], )

    def default_traits_view(self):
        """traits hook returning ``default_view``."""
        return self.default_view
Exemplo n.º 28
0
class TetrahedraRenderLayer(VertexRenderLayer):
    """
    Draw points as a surface of triangles built by ``gen3DTriangs``.

    Triangles are drawn with the Gouraud shader, or with the wireframe
    shader when constructed with ``is_wire_frame=True``.
    """
    z_rescale = Float(1.0)
    size_cutoff = Float(1000.)
    internal_cull = Bool(True)
    wireframe = Bool(False)

    DRAW_MODE = GL_TRIANGLES

    def __init__(self,
                 x=None,
                 y=None,
                 z=None,
                 colors=None,
                 color_map=None,
                 size_cutoff=None,
                 internal_cull=None,
                 z_rescale=None,
                 alpha=None,
                 is_wire_frame=False):
        """
        Parameters
        ----------
        x, y, z : array-like, optional
            Point coordinates. When ``x`` is given, the points are
            triangulated immediately.
        colors : optional
            ``'z'`` colours by (rescaled) z. Any other value colours
            by ``1 / a`` from ``gen3DTriangs``.
        color_map : callable, optional
            Colour map mapping normalised values to RGBA.
        size_cutoff : float, optional
            Passed to ``gen3DTriangs``. Falsy values keep the class default.
        internal_cull : bool, optional
            Passed to ``gen3DTriangs`` as ``internalCull``.
        z_rescale : float, optional
            z is divided by this before triangulation. Falsy values keep
            the class default.
        alpha : float, optional
            Opacity. None means 1.0 when colouring the initial data.
        is_wire_frame : bool
            Use the wireframe shader instead of the Gouraud shader.
        """
        super(TetrahedraRenderLayer, self).__init__(colors=colors,
                                                    color_map=color_map,
                                                    alpha=alpha)
        # 0 is meaningless for both (z_rescale is a divisor), so falsy values
        # keep the class defaults
        if size_cutoff:
            self.size_cutoff = size_cutoff

        if z_rescale:
            self.z_rescale = z_rescale

        # compare against None so that internal_cull=False is honoured
        if internal_cull is not None:
            self.internal_cull = internal_cull

        # ``if x:`` is ambiguous (ValueError) for numpy arrays
        if x is not None:
            p, a, n = gen3DTriangs(x,
                                   y,
                                   z / self.z_rescale,
                                   self.size_cutoff,
                                   internalCull=self.internal_cull)
            # isscalar guard: ``array == 'z'`` would be elementwise
            if np.isscalar(colors) and colors == 'z':
                colors = p[:, 2]
            else:
                colors = 1. / a
            color_limit = [colors.min(), colors.max()]
            # Colours are per triangulated vertex, so pass the triangulated
            # vertices (as update_from_datasource does), not the raw points.
            self.update_data(p[:, 0],
                             p[:, 1],
                             p[:, 2],
                             colors,
                             cmap=color_map,
                             clim=color_limit,
                             alpha=1.0 if alpha is None else alpha)

        if is_wire_frame:
            self.set_shader_program(WireFrameShaderProgram)
        else:
            self.set_shader_program(GouraudShaderProgram)

    def update_from_datasource(self, ds, cmap=None, clim=None, alpha=1.0):
        """
        Triangulate the points in ``ds`` and upload them for rendering.

        z is taken from ``self.z_key`` when set, otherwise it is all zeros.
        Colouring is by ``1 / a`` from ``gen3DTriangs``.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        if self.z_key is not None:
            z = ds[self.z_key]
        else:
            z = 0 * x

        p, a, n = gen3DTriangs(x,
                               y,
                               z / self.z_rescale,
                               self.size_cutoff,
                               internalCull=self.internal_cull)

        # TODO - colour by ds[self.vertexColour] once interpolated triangle
        # colouring is supported
        c = 1. / a

        self.update_data(p[:, 0],
                         p[:, 1],
                         p[:, 2],
                         c,
                         cmap=cmap,
                         clim=clim,
                         alpha=alpha)

    def update_data(self,
                    x=None,
                    y=None,
                    z=None,
                    colors=None,
                    cmap=None,
                    clim=None,
                    alpha=1.0):
        """
        Build vertex, normal and RGBA colour arrays and pass them to ``set_values``.

        Colours are mapped through ``cmap`` over ``clim`` only when ``cmap``,
        ``colors`` and ``clim`` are all given. Otherwise every vertex is
        opaque white, or there are no colours if there are no vertices.
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            # (n, 3) C-contiguous vertex array
            vertices = np.column_stack((x.ravel(), y.ravel(), z.ravel()))
            # constant placeholder normals
            normals = -0.69 * np.ones(vertices.shape)
        else:
            vertices = None
            normals = None

        # cmap must be checked too: it is called below
        if cmap is not None and colors is not None and clim is not None:
            # a zero-width colour range would divide by zero
            span = (clim[1] - clim[0]) or 1.0
            cs = cmap((colors - clim[0]) / span)
            cs[:, 3] = alpha
            cs = cs.ravel().reshape(len(colors), 4)
        elif vertices is not None:
            cs = np.ones((vertices.shape[0], 4), 'f')
        else:
            cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def render(self, gl_canvas):
        """
        Draw the stored triangles with this layer's shader program.

        Parameters
        ----------
        gl_canvas
            Not used; accepted for a uniform layer interface.
        """
        with self.shader_program:
            vertices = self.get_vertices()
            n_vertices = vertices.shape[0]

            glVertexPointerf(vertices)
            glNormalPointerf(self.get_normals())
            glColorPointerf(self.get_colors())

            glPushMatrix()
            glDrawArrays(self.DRAW_MODE, 0, n_vertices)

            glPopMatrix()