Example #1
class PSFSettings(HasTraits):
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])
    four_pi = Bool(False)

    def default_traits_view(self):
        from traitsui.api import View, Item
        #from PYME.ui.custom_traits_editors import CBEditor

        return View(Item(name='wavelength_nm'),
                    Item(name='NA'),
                    Item(name='vectorial'),
                    Item(name='four_pi', label='4Pi'),
                    Item(name='zernike_modes'),
                    Item(name='zernike_modes_lower',
                         visible_when='four_pi==True'),
                    Item(name='phases',
                         visible_when='four_pi==True',
                         label='phases/pi'),
                    resizable=True,
                    buttons=['OK'])
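

# A minimal usage sketch (hypothetical values): PSFSettings is a plain
# HasTraits object, so traits can be set programmatically before opening the
# traitsui dialog defined above.
def _demo_psf_settings():
    settings = PSFSettings(wavelength_nm=647., NA=1.4, zernike_modes={4: 0.1})
    # settings.configure_traits() would show the View from default_traits_view()
    return settings
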
class CombineBeadStacks(ModuleBase):
    """
        Combine multiple bead stacks in the 4th dimension.
        X, Y, Z dimensions must be identical.
    """
    
    inputName = Input('dummy')
    
    files = List(File, ['', ''], 2)
    cache = File()
    
    outputName = Output('bead_images')
    
    def execute(self, namespace):
        
        ims = ImageStack(filename=self.files[0])
        dims = np.asarray(ims.data.shape, dtype=np.int64)
        dims[3] = 0
        dtype_ = ims.data[:,0,0,0].dtype
        mdh = ims.mdh
        del ims
        
        for fil in self.files:
            ims = ImageStack(filename=fil)
            dims[3] += ims.data.shape[3]
            del ims
        
        if self.cache != '':
            raw_data = np.memmap(self.cache, dtype=dtype_, mode='w+', shape=tuple(dims))
        else:
            raw_data = np.zeros(shape=tuple(dims), dtype=dtype_)
        
        counter = 0
        for fil in self.files:
            ims = ImageStack(filename=fil)
            c_len = ims.data.shape[3]
            data = ims.data[:,:,:,:]
            data.shape += (1,) * (4 - data.ndim)
            raw_data[:,:,:,counter:counter+c_len] = data
            counter += c_len
            del ims
            
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(mdh)
            new_mdh["PSFExtraction.SourceFilenames"] = self.files
        except Exception as e:
            print(e)
            
        namespace[self.outputName] = ImageStack(data=raw_data, mdh=new_mdh)
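

# A numpy-only sketch of the concatenation bookkeeping above (toy shapes, no
# ImageStack or memmap): stacks with identical X, Y, Z are packed along the
# 4th axis, exactly as CombineBeadStacks.execute does.
def _demo_concat_axis3():
    stacks = [np.zeros((8, 8, 16, n)) for n in (3, 5)]
    out = np.zeros(stacks[0].shape[:3] + (sum(s.shape[3] for s in stacks),))
    counter = 0
    for s in stacks:
        out[:, :, :, counter:counter + s.shape[3]] = s
        counter += s.shape[3]
    return out  # shape (8, 8, 16, 8)
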
class Binning(CacheCleanupModule):
    """
    Downsample 3D data (mean).
    X, Y pixels that don't fill a full bin are dropped.
    Pixels in the 3rd dimension can have a partially filled bin.
        
    Inputs
    ------
    inputName : ImageStack
    
    Outputs
    -------
    outputName : ImageStack
    
    Parameters
    ----------
    x_start : int
        Starting index in x.
    x_end : Int
        Stopping index in x.
    y_start : Int
        Starting index in y.
    y_end : Int
        Stopping index in y.
    binsize : List
        Bin size (pixels) in x, y, z.
    cache_bin : File
        Use file as disk cache if provided.
    """

    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    #    z_start = Int(0)
    #    z_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.inputName]

        binsize = np.asarray(self.binsize, dtype=int)
        #        print (binsize)

        # unconventional: end stop is inclusive
        x_slice = np.arange(ims.data.shape[0] + 1)[slice(
            self.x_start, self.x_end, 1)]
        y_slice = np.arange(ims.data.shape[1] + 1)[slice(
            self.y_start, self.y_end, 1)]
        x_slice = x_slice[:x_slice.shape[0] // binsize[0] * binsize[0]]
        y_slice = y_slice[:y_slice.shape[0] // binsize[1] * binsize[1]]
        #        print x_slice, len(x_slice)
        #        print y_slice, len(y_slice)
        bincounts = np.asarray([
            len(x_slice) // binsize[0],
            len(y_slice) // binsize[1], -(-ims.data.shape[2] // binsize[2])
        ],
                               dtype=np.int64)

        x_slice_ind = slice(x_slice[0], x_slice[-1] + 1)
        y_slice_ind = slice(y_slice[0], y_slice[-1] + 1)

        #        print (bincounts)
        new_shape = np.stack([bincounts, binsize], -1).flatten()
        #        print(new_shape)

        # need to wrap this to work for multiple colour channel images
        #        binned_image = ims.data[:,:,:].reshape(new_shape)
        dtype = ims.data[:, :, 0].dtype

        #        print bincounts
        binned_image = np.memmap(self.cache_bin,
                                 dtype=dtype,
                                 mode='w+',
                                 shape=tuple(
                                     np.asarray(bincounts, dtype=np.int64)))
        #        print binned_image.shape

        new_shape_one_chunk = new_shape.copy()
        new_shape_one_chunk[4] = 1
        new_shape_one_chunk[5] = -1
        #        print new_shape_one_chunk
        progress = 0.2 * ims.data.shape[2]
        #        print
        for i, f in enumerate(np.arange(0, ims.data.shape[2], binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind,
                                      f:f + binsize[2]].squeeze()

            binned_image[:, :,
                         i] = raw_data_chunk.reshape(new_shape_one_chunk).mean(
                             (1, 3, 5)).squeeze()

            if (f + binsize[2] >= progress):
                binned_image.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed binning {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + binsize[2], ims.data.shape[2]),
                             ims.data.shape[2]))


#        print(type(binned_image))
        im = ImageStack(binned_image, titleStub=self.outputName)
        #        print(type(im.data))
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename
        try:
            ### Metadata must be logged correctly for the measured drift to be applicable to the source image
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            #            im.mdh['voxelsize.z'] *= binsize[2]
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except:
            pass

        namespace[self.outputName] = im
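

# The heart of Binning._execute is the reshape-then-mean trick: a block of
# shape (nx*bx, ny*by, nz*bz) is viewed as (nx, bx, ny, by, nz, bz) and
# averaged over the bin axes (1, 3, 5). A self-contained toy sketch (in-memory
# array, no memmap cache):
def _demo_reshape_binning():
    data = np.arange(4 * 4 * 4, dtype=float).reshape(4, 4, 4)
    bx = by = bz = 2
    binned = data.reshape(2, bx, 2, by, 2, bz).mean(axis=(1, 3, 5))
    return binned  # shape (2, 2, 2); each value is the mean of one 2x2x2 bin
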
Example #4
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, using one of 3 engines (indicated above)
    
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='points',
                 dsname='',
                 context=None,
                 **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        #logger.debug('Setting dsname and method')
        self.dsname = dsname
        self.method = method

        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        #logger.debug('Setting layer method to %s' % self.method)
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        try:
            cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        #self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if self.datasource is not None:
            self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        x, y = ds[self.x_key], ds[self.y_key]

        if self.z_key is not None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if self.vertexColour != '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x,
                             y,
                             z,
                             c,
                             cmap=getattr(cm, self.cmap),
                             clim=self.clim,
                             alpha=self.alpha,
                             xn=xn,
                             yn=yn,
                             zn=zn)
        else:
            self.update_data(x,
                             y,
                             z,
                             c,
                             cmap=getattr(cm, self.cmap),
                             clim=self.clim,
                             alpha=self.alpha)

    def update_data(self,
                    x=None,
                    y=None,
                    z=None,
                    colors=None,
                    cmap=None,
                    clim=None,
                    alpha=1.0,
                    xn=None,
                    yn=None,
                    zn=None):
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if xn is not None:
                normals = np.vstack(
                    (xn.ravel(), yn.ravel(),
                     zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array(
                [x.min(), y.min(),
                 z.min(), x.max(),
                 y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        if clim is not None and colors is not None and cmap is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        else:
            #cs = None
            if vertices is not None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None
            color_map = None
            color_limit = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self,
                   vertices=None,
                   normals=None,
                   colors=None,
                   color_map=None,
                   color_limit=None,
                   alpha=None):
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     visible_when=
                     "method in ['pointsprites', 'transparent_points']",
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)),
                Item('point_size',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)))
        ])
        #buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
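

# Standalone sketch of the colour-mapping step in update_data above: values
# are scaled to [0, 1] using clim, pushed through a colormap, and the alpha
# channel is overwritten. Uses matplotlib's cm directly for illustration; the
# layer itself uses PYME's cm wrapper.
def _demo_colour_mapping():
    import matplotlib.cm as mcm
    colors = np.linspace(10., 20., 5)
    clim = [10., 20.]
    cs = mcm.viridis((colors - clim[0]) / (clim[1] - clim[0]))  # (5, 4) RGBA
    cs[:, 3] = 0.5  # override alpha, as update_data does
    return cs
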
Example #5
class ModuleCollection(HasTraits):
    modules = List()
    execute_on_invalidation = Bool(False)

    def __init__(self, *args, **kwargs):
        HasTraits.__init__(self, *args, **kwargs)

        self.namespace = {}

        # we open hdf files and don't necessarily read their contents into memory - these need to be closed when we
        # either delete the recipe, or clear the namespace
        self._open_input_files = []

        self.recipe_changed = dispatch.Signal()
        self.recipe_executed = dispatch.Signal()

    def invalidate_data(self):
        if self.execute_on_invalidation:
            self.execute()

    def clear(self):
        self.namespace.clear()

    def new_output_name(self, stub):
        count = len([k for k in self.namespace.keys() if k.startswith(stub)])

        if count == 0:
            return stub
        else:
            return '%s_%d' % (stub, count)
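
    # Naming behaviour example (hypothetical namespace contents): with keys
    # 'filtered' and 'filtered_1' present, new_output_name('filtered') counts
    # two matching keys and returns 'filtered_2'.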

    def dependancyGraph(self):
        dg = {}

        #only add items to dependency graph if they are not already in the namespace
        #calculated_objects = namespace.keys()

        for mod in self.modules:
            #print mod
            s = mod.inputs

            try:
                s.update(dg[mod])
            except KeyError:
                pass

            dg[mod] = s

            for op in mod.outputs:
                #if not op in calculated_objects:
                dg[op] = {
                    mod,
                }

        return dg
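
    # Shape of the resulting graph (hypothetical modules and names): module
    # objects map to the set of names they consume, and each output name maps
    # to the module that produces it, e.g.
    #   {<Filter>: {'input'}, 'filtered': {<Filter>},
    #    <Binning>: {'filtered'}, 'binned': {<Binning>}}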

    def reverseDependancyGraph(self):
        dg = self.dependancyGraph()

        rdg = {}

        for k, vs in dg.items():
            for v in vs:
                vdeps = set()
                try:
                    vdeps = rdg[v]
                except KeyError:
                    pass

                vdeps.add(k)
                rdg[v] = vdeps

        return rdg

    def _getAllDownstream(self, rdg, keys):
        """get all the downstream items which depend on the given key"""

        downstream = set()

        next_level = set()

        for k in keys:
            try:
                next_level.update(rdg[k])
            except KeyError:
                pass

        if len(next_level) > 0:

            downstream.update(next_level)

            downstream.update(self._getAllDownstream(rdg, list(next_level)))

        return downstream

    def prune_dependencies_from_namespace(self,
                                          keys_to_prune,
                                          keep_passed_keys=False):
        rdg = self.reverseDependancyGraph()

        if keep_passed_keys:
            downstream = list(self._getAllDownstream(rdg, list(keys_to_prune)))
        else:
            downstream = list(keys_to_prune) + list(
                self._getAllDownstream(rdg, list(keys_to_prune)))

        #print downstream

        for dsi in downstream:
            try:
                self.namespace.pop(dsi)
            except KeyError:
                #the output is not in our namespace, no need to prune
                pass
            except AttributeError:
                #we might not have our namespace defined yet
                pass

    def resolveDependencies(self):
        import toposort
        #build dependency graph

        dg = self.dependancyGraph()

        #solve the dependency tree
        return toposort.toposort_flatten(dg, sort=False)
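
    # toposort_flatten linearises the mixed graph of modules and data names
    # into an execution order, e.g. (hypothetical):
    #   ['input', <Filter>, 'filtered', <Binning>, 'binned']
    # execute() below then runs only the ModuleBase instances in that order.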

    def execute(self, **kwargs):
        #remove anything which is downstream from changed inputs
        #print self.namespace.keys()
        for k, v in kwargs.items():
            #print k, v
            try:
                if not (self.namespace[k] == v):
                    #input has changed
                    print('pruning: ', k)
                    self.prune_dependencies_from_namespace([k])
            except KeyError:
                #key wasn't in namespace previously
                print('KeyError')
                pass

        self.namespace.update(kwargs)

        exec_order = self.resolveDependencies()

        for m in exec_order:
            if isinstance(
                    m,
                    ModuleBase) and not m.outputs_in_namespace(self.namespace):
                try:
                    m.execute(self.namespace)
                except:
                    logger.exception("Error in recipe module: %s" % m)
                    raise

        self.recipe_executed.send_robust(self)

        if 'output' in self.namespace.keys():
            return self.namespace['output']

    @classmethod
    def fromMD(cls, md):
        c = cls()

        moduleNames = set([s.split('.')[0] for s in md.keys()])

        mc = []

        for mn in moduleNames:
            mod = all_modules[mn]()
            mod.set(**md[mn])
            mc.append(mod)

        #return cls(modules=mc)
        c.modules = mc

        return c

    def get_cleaned_module_list(self):
        l = []
        for mod in self.modules:
            #l.append({mod.__class__.__name__: mod.get()})

            ct = mod.class_traits()

            mod_traits_cleaned = {}
            for k, v in mod.get().items():
                if not k.startswith(
                        '_'
                ):  #don't save private data - this is usually used for caching etc ..,
                    try:
                        if (not (v == ct[k].default)) or (k.startswith(
                                'input')) or (k.startswith('output')):
                            #don't save defaults
                            if isinstance(v, dict) and not type(v) == dict:
                                v = dict(v)
                            elif isinstance(v, list) and not type(v) == list:
                                v = list(v)
                            elif isinstance(v, set) and not type(v) == set:
                                v = set(v)

                            mod_traits_cleaned[k] = v
                    except KeyError:
                        # for some reason we have a trait that shouldn't be here
                        pass

            l.append({module_names[mod.__class__]: mod_traits_cleaned})

        return l

    def toYAML(self):
        import yaml

        class MyDumper(yaml.SafeDumper):
            def represent_mapping(self, tag, value, flow_style=None):
                return super(MyDumper,
                             self).represent_mapping(tag, value, False)

        return yaml.dump(self.get_cleaned_module_list(), Dumper=MyDumper)
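
    # Hypothetical toYAML() output for a one-module recipe, matching the
    # list-of-single-entry-dicts format documented in _update_from_module_list:
    #   - Filtering.Filter:
    #       filters: {probe: [-0.5, 0.5]}
    #       input: localizations
    #       output: filtered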

    def toJSON(self):
        import json
        return json.dumps(self.get_cleaned_module_list())

    def _update_from_module_list(self, l):
        """
        Update from a parsed yaml or json list of modules
        
        It probably makes no sense to call this directly as the format is pretty awkward - a
        list of dictionaries, each with a single entry, but that is how the yaml parses

        Parameters
        ----------
        l: list
            List of modules as obtained from parsing a yaml recipe.
            Each module is a dictionary with a single entry, e.g.
            [{'Filtering.Filter': {'filters': {'probe': [-0.5, 0.5]}, 'input': 'localizations', 'output': 'filtered'}}]

        Returns
        -------

        """
        mc = []

        if l is None:
            l = []

        for mdd in l:
            mn, md = list(mdd.items())[0]
            try:
                mod = all_modules[mn](self)
            except KeyError:
                # still support loading old recipes which do not use hierarchical names
                # also try and support modules which might have moved
                mod = _legacy_modules[mn.split('.')[-1]](self)

            mod.set(**md)
            mc.append(mod)

        self.modules = mc

        self.recipe_changed.send_robust(self)
        self.invalidate_data()

    @classmethod
    def _from_module_list(cls, l):
        """ A factory method which contains the common logic for loading/creating from either
        yaml or json. Do not call directly"""
        c = cls()
        c._update_from_module_list(l)

        return c

    @classmethod
    def fromYAML(cls, data):
        import yaml

        l = yaml.safe_load(data)
        return cls._from_module_list(l)

    def update_from_yaml(self, data):
        """
        Update from a yaml formatted recipe description

        Parameters
        ----------
        data: str
            either yaml formatted text, or the path to a yaml file.

        Returns
        -------
        None

        """
        import os
        import yaml

        if os.path.isfile(data):
            with open(data) as f:
                data = f.read()

        l = yaml.safe_load(data)
        return self._update_from_module_list(l)

    @classmethod
    def fromJSON(cls, data):
        import json
        return cls._from_module_list(json.loads(data))
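
    # Round-trip usage sketch (assumes 'Filtering.Filter' is a registered
    # module and recipe.yaml uses the format shown above):
    #   recipe = ModuleCollection.fromYAML(open('recipe.yaml').read())
    #   recipe.namespace['localizations'] = some_tabular_source  # hypothetical
    #   filtered = recipe.execute()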

    def add_module(self, module):
        self.modules.append(module)
        self.recipe_changed.send_robust(self)

    @property
    def inputs(self):
        ip = set()
        for mod in self.modules:
            ip.update({k for k in mod.inputs if k.startswith('in')})
        return ip

    @property
    def outputs(self):
        op = set()
        for mod in self.modules:
            op.update({k for k in mod.outputs if k.startswith('out')})
        return op

    @property
    def module_outputs(self):
        op = set()
        for mod in self.modules:
            op.update(set(mod.outputs))
        return op

    @property
    def file_inputs(self):
        out = []
        for mod in self.modules:
            out += mod.file_inputs

        return out

    def save(self, context={}):
        """
        Find all OutputModule instances and call their save methods with the recipe context

        Parameters
        ----------
        context : dict
            A context dictionary used to substitute and create variable names.

        """
        for mod in self.modules:
            if isinstance(mod, OutputModule):
                mod.save(self.namespace, context)

    def gather_outputs(self, context={}):
        """
        Find all OutputModule instances and call their generate methods with the recipe context

        Parameters
        ----------
        context : dict
            A context dictionary used to substitute and create variable names.

        """

        outputs = []

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                out = mod.generate(self.namespace, context)

                if out is not None:
                    outputs.append(out)

        return outputs

    def loadInput(self, filename, key='input'):
        """Load input data from a file and inject into namespace

        Currently only handles images (anything you can open in dh5view). TODO -
        extend to other types.
        """
        #modify this to allow for different file types - currently only supports images
        from PYME.IO import unifiedIO
        import os
        extension = os.path.splitext(filename)[1]
        if extension in ['.h5r', '.h5', '.hdf']:
            import tables
            from PYME.IO import MetaDataHandler
            from PYME.IO import tabular

            with unifiedIO.local_or_temp_filename(filename) as fn:
                with tables.open_file(fn, mode='r') as h5f:
                    #make sure our hdf file gets closed

                    key_prefix = '' if key == 'input' else key + '_'

                    try:
                        mdh = MetaDataHandler.NestedClassMDHandler(
                            MetaDataHandler.HDFMDHandler(h5f))
                    except tables.FileModeError:  # Occurs if no metadata is found, since we opened the table in read-mode
                        logger.warning(
                            'No metadata found, proceeding with empty metadata'
                        )
                        mdh = MetaDataHandler.NestedClassMDHandler()

                    for t in h5f.list_nodes('/'):
                        # FIXME - The following isinstance tests are not very safe (and badly broken in some cases e.g.
                        # PZF formatted image data, Image data which is not in an EArray, etc ...)
                        # Note that EArray is only used for streaming data!
                        # They should ideally be replaced with more comprehensive tests (potentially based on array or dataset
                        # dimensionality and/or data type) - i.e. duck typing. Our strategy for images in HDF should probably
                        # also be improved / clarified - can we use hdf attributes to hint at the data intent? How do we support
                        # > 3D data?

                        if isinstance(t, tables.VLArray):
                            from PYME.IO.ragged import RaggedVLArray

                            rag = RaggedVLArray(
                                h5f, t.name, copy=True
                            )  #force an in-memory copy so we can close the hdf file properly
                            rag.mdh = mdh

                            self.namespace[key_prefix + t.name] = rag

                        elif isinstance(t, tables.table.Table):
                            #  pipe our table into h5r or hdf source depending on the extension
                            tab = tabular.H5RSource(
                                h5f, t.name
                            ) if extension == '.h5r' else tabular.HDFSource(
                                h5f, t.name)
                            tab.mdh = mdh

                            self.namespace[key_prefix + t.name] = tab

                        elif isinstance(t, tables.EArray):
                            # load using ImageStack._loadh5, which finds metadata
                            im = ImageStack(filename=filename, haveGUI=False)
                            # assume image is the main table in the file and give it the named key
                            self.namespace[key] = im

        elif extension == '.csv':
            logger.error('loading .csv not supported yet')
            raise NotImplementedError
        elif extension in ['.xls', '.xlsx']:
            logger.error('loading .xls not supported yet')
            raise NotImplementedError
        else:
            self.namespace[key] = ImageStack(filename=filename, haveGUI=False)

    @property
    def pipeline_view(self):
        import wx
        if wx.GetApp() is None:
            return None
        else:
            from traitsui.api import View, ListEditor, InstanceEditor, Item
            #v = tu.View(tu.Item('modules', editor=tu.ListEditor(use_notebook=True, view='pipeline_view'), style='custom', show_label=False),
            #            buttons=['OK', 'Cancel'])

            return View(Item('modules',
                             editor=ListEditor(
                                 style='custom',
                                 editor=InstanceEditor(view='pipeline_view'),
                                 mutable=False),
                             style='custom',
                             show_label=False),
                        buttons=['OK', 'Cancel'])

    def to_svg(self):
        from . import recipeLayout
        return recipeLayout.to_svg(self.dependancyGraph())

    def _repr_svg_(self):
        """ Make us look pretty in Jupyter"""
        return self.to_svg()
Example #6
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        self._do = display_opts #a dh5view display_options instance - if provided, this overrides the clim, cmap properties
        
        self._im_key = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        #self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (self._pipeline is not None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property), falling back to a plain dictionary lookup if the pipeline is a dict.
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            #fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]
        #return self.datasource
    
    @property
    def _ds_class(self):
        # from PYME.experimental import triangle_mesh
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()


    # def _update(self, *args, **kwargs):
    #     #pass
    #     cdata = self._get_cdata()
    #     self.clim = [float(cdata.min()), float(cdata.max())]
    #     self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        except AttributeError:
            self._datasource_choices = [k for k, v in self._pipeline.items() if
                                        isinstance(v, self._ds_class)]
        
        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox
    
    def sync_to_display_opts(self, do=None):
        if do is None:
            if self._do is not None:
                do = self._do
            else:
                return

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]
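        # i.e. if the display maps (data - offset) * gain onto [0, 1], the
        # equivalent colour limits are [offset, offset + 1/gain]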

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]
        
        self.set(clim=clim, cmap=cmap, visible=visible)
        

    def update_from_datasource(self, ds):
        """

        Parameters
        ----------
        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """

        
        #if self._do is not None:
            # Let display options (if provided) override our settings (TODO - is this the right way to do this?)
        #    o = self._do.Offs[self.channel]
        #    g = self._do.Gains[self.channel]
        #    clim = [o, o + 1.0/g]
            #self.clim = clim
            
        #    cmap = self._do.cmaps[self.channel]
            #self.visible = self._do.show[self.channel]
        #else:
        
        clim = self.clim
        cmap = getattr(cm, self.cmap)
            
        alpha = float(self.alpha)
        
        c0, c1 = clim
        
        im_key = (self.dsname, self.slice, self.channel)
        
        if self._im_key != im_key:
            self._im_key = im_key
            self._im = ds.data[:,:,self.slice, self.channel].astype('f4')# - c0)/(c1-c0)
        
            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])
        
            self._bounds = [x0, y0, x1, y1]
            
        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit
    
    def _get_cdata(self):
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     #Item('method'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
Example #7
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).
        """
        return self._pipeline.get_layer_data(self.dsname)
        #return self.datasource
    
    @property
    def _ds_class(self):
        # from PYME.experimental import triangle_mesh
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        #pass
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        
        if self.datasource is not None:
            dks = ['constant',]
            if hasattr(self.datasource, 'keys'):
                 dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks
        
        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Pulls vertices/normals from the triangular-mesh datasource and updates the layer's vertex, normal, and colour buffers.

        Parameters
        ----------
        ds :
            PYME.experimental.triangular_mesh.TriangularMesh object

        Returns
        -------
        None
        """
        #t = ds.vertices[ds.faces]
        #n = ds.vertex_normals[ds.faces]
        
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T
        
        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)
            
        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        #elif self.vertexColour == 'vertex_index':
        #    c = np.arange(0, len(x))
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if xn is not None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        # TODO: This temporarily sets all triangles to the color red. User should be able to select color.
        if c is None:
            c = np.ones(self._vertices.shape[0]) * 255  # vector of pink
            
        

        if clim is not None and c is not None and cmap is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                alpha = cs_ * alpha
            
            cs[:, 3] = alpha
            
            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            # cs = None
            if self._vertices is not None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')
            
        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim


    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
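

# Sketch of the per-face normal expansion used in update_from_datasource
# above: np.repeat copies each face normal to the face's three vertices
# (toy data).
def _demo_face_normals():
    face_normals = np.array([[0., 0., 1.], [1., 0., 0.]])  # (n_faces, 3)
    xn, yn, zn = np.repeat(face_normals.T, 3, axis=1)  # each (3 * n_faces,)
    return xn, yn, zn
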
class InterpolatePSF(ModuleBase):
    """
        Interpolate PSF with RBF.
        Very stupid. Very slow. Performed on local pixels and combined by tiling.
        Only uses the first colour channel.
    """
    
    inputName = Input('input')
    rbf_radius = Float(250.0)
    target_voxelsize = List(Float, [100., 100., 100.])

    output_images = Output('psf_interpolated')
    
    def execute(self, namespace):        
        ims = namespace[self.inputName]
        data = ims.data[:,:,:,0]
        
        dims_original = list()
        voxelsize = [ims.mdh.voxelsize.x, ims.mdh.voxelsize.y, ims.mdh.voxelsize.z]
        for dim, dim_len in enumerate(data.shape):
            d = np.linspace(0, dim_len-1, dim_len) * voxelsize[dim] * 1E3
            d -= d.mean()        
            dims_original.append(d)
        X, Y, Z = np.meshgrid(*dims_original, indexing='ij')
        
        dims_interpolated = list()
        for dim, dim_len in enumerate(data.shape):
            tar_len = int(np.ceil((voxelsize[dim]*1E3 * dim_len) / self.target_voxelsize[dim]))
            d = np.arange(tar_len) * self.target_voxelsize[dim]
            d -= d.mean()
            dims_interpolated.append(d)
        
        X_interp, Y_interp, Z_interp = np.meshgrid(*dims_interpolated, indexing='ij')
        pts_interp = zip(*[X_interp.flatten(), Y_interp.flatten(), Z_interp.flatten()])
        
        results = np.zeros(X_interp.size)
        for i, pt in enumerate(pts_interp):
            results[i] = self.InterpolateAt(pt, X, Y, Z, data[:,:,:], self.rbf_radius)
            if i % 100 == 0:
                print("{} out of {} completed.".format(i, results.shape[0]))
        
        results = results.reshape(len(dims_interpolated[0]), len(dims_interpolated[1]), len(dims_interpolated[2]))
#        return results
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)            
            new_mdh["voxelsize.x"] = self.target_voxelsize[0] * 1E-3
            new_mdh["voxelsize.y"] = self.target_voxelsize[1] * 1E-3
            new_mdh["voxelsize.z"] = self.target_voxelsize[2] * 1E-3
            new_mdh['Interpolation.Method'] = 'RBF'
            new_mdh['Interpolation.RbfRadius'] = self.rbf_radius
        except Exception as e:
            print(e)
        namespace[self.output_images] = ImageStack(data=results, mdh=new_mdh)
        
    def InterpolateAt(self, pt, X, Y, Z, data, radius=250.):
        X_subset, Y_subset, Z_subset, data_subset = self.GetPointsInNeighbourhood(pt, X, Y, Z, data, radius)
        rbf = interpolate.Rbf(X_subset, Y_subset, Z_subset, data_subset, function="cubic", smooth=1E3)#, norm=euclidean_norm_numpy)
    #     print pt
        return rbf(*pt)
    
    def GetPointsInNeighbourhood(self, center, X, Y, Z, data, radius=250.):
        distance = np.sqrt((X-center[0])**2 + (Y-center[1])**2 + (Z-center[2])**2)
    #     print distance.shape
        mask = distance < radius
        
        return X[mask], Y[mask], Z[mask], data[mask]
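

# A self-contained sketch of the RBF step above on toy 1D data instead of PSF
# voxels (scipy's interpolate.Rbf, as used by InterpolateAt):
def _demo_rbf():
    x = np.linspace(0, 10, 20)
    y = np.sin(x)
    rbf = interpolate.Rbf(x, y, function='cubic', smooth=0)
    return rbf(5.0)  # interpolated estimate near sin(5.0)
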
class AveragePSF(ModuleBase):
    """
        Takes a stack of PSFs and returns the (normalized) average PSF.
        Additionally filters based on the max error/residual between each image and the averaged image.
    """
    
    inputName = Input('psf_aligned')
    normalize_intensity = Bool(False)
#    normalize_z = Bool(False)
    output_var_image = Output('psf_var')
#    smoothing_method = Enum(['RBF', 'Gaussian'])
#    output_var_image_norm = Output('psf_var_norm')
    gaussian_filter = List(Float, [0, 0, 0], 3, 3)
    residual_threshold = Float(0.1)
    output_images = Output('psf_combined')
    
    def execute(self, namespace):        
        ims = namespace[self.inputName]
        psf_raw = ims.data[:,:,:,:]
        
        # always normalize first, since needed for statistics
        psf_raw_norm = psf_raw.copy()
#        if self.normalize_intensity == True:
        psf_raw_norm /= psf_raw_norm.max(axis=(0,1,2), keepdims=True)
        psf_raw_norm /= psf_raw_norm.sum(axis=(0,1), keepdims=True)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()
        
        residual_max = np.abs(psf_raw_norm - psf_raw_norm.mean(axis=3, keepdims=True)).max(axis=(0,1,2))
        print(residual_max)
        mask = residual_max < self.residual_threshold
        print "images ignore: {}".format(np.argwhere(~mask)[:,0])
        print mask
        psf_raw_norm = psf_raw_norm[:,:,:,mask]
        print(psf_raw_norm.shape)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()
        
        psf_var = psf_raw_norm.var(axis=3)  
        
#        psf_var_norm = psf_var / psf_combined.mean(axis=3)
        namespace[self.output_var_image] = ImageStack(psf_var, mdh=ims.mdh)
#        namespace[self.output_var_image_norm] = ImageStack(np.nan_to_num(psf_var_norm), mdh=ims.mdh)
        
        # if requested not to normalize, revert back to original data
        if not self.normalize_intensity:
            psf_raw_norm = psf_raw.copy()[:,:,:,mask]
        
        psf_combined = psf_raw_norm.mean(axis=3)
        psf_combined -= psf_combined.min()
        psf_combined /= psf_combined.max()
        
#        if self.smoothing_method == 'RBF':
#            dims = [np.arange(i) for i in psf_combined.shape]
            
#        elif self.smoothing_method == 'Gaussian' and
        if np.any(np.asarray(self.gaussian_filter)!=0):
            psf_processed = ndimage.gaussian_filter(psf_combined, self.gaussian_filter)
        else:
            psf_processed = psf_combined
            
        psf_processed -= psf_processed.min()
        psf_processed /= psf_processed.max()
        
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["PSFExtraction.GaussianFilter"] = self.gaussian_filter
            new_mdh["PSFExtraction.NormalizeIntensity"] = self.normalize_intensity            
        except Exception as e:
            print(e)
        namespace[self.output_images] = ImageStack(psf_processed, mdh=new_mdh)
        
        if True:
            fig, axes = pyplot.subplots(2, 3, figsize=(9,6))
            axes[0,0].set_title('X')
            axes[0,0].plot(psf_raw_norm[:, psf_raw_norm.shape[1]//2, psf_raw_norm.shape[2]//2, :])
            axes[1,0].plot(psf_combined[:, psf_combined.shape[1]//2, psf_combined.shape[2]//2], lw=1, color='red')
            axes[1,0].plot(psf_processed[:, psf_processed.shape[1]//2, psf_processed.shape[2]//2], lw=1, ls='--', color='black')
            axes[0,1].set_title('Y')
            axes[0,1].plot(psf_raw_norm[psf_raw_norm.shape[0]//2, :, psf_raw_norm.shape[2]//2, :])
            axes[1,1].plot(psf_combined[psf_combined.shape[0]//2, :, psf_combined.shape[2]//2], lw=1, color='red')
            axes[1,1].plot(psf_processed[psf_processed.shape[0]//2, :, psf_processed.shape[2]//2], lw=1, ls='--', color='black')
            axes[0,2].set_title('Z')
            axes[0,2].plot(psf_raw_norm[psf_raw_norm.shape[0]//2, psf_raw_norm.shape[1]//2, :, :])
            axes[1,2].plot(psf_combined[psf_combined.shape[0]//2, psf_combined.shape[1]//2, :], lw=1, color='red')
            axes[1,2].plot(psf_processed[psf_processed.shape[0]//2, psf_processed.shape[1]//2, :], lw=1, ls='--', color='black')
            
            fig.tight_layout()
            
            fig, ax = pyplot.subplots(1, 1, figsize=(4,3))
            ax.hist(residual_max, bins=20)
            ax.axvline(self.residual_threshold, color='red', ls='--')
class CropPSF(ModuleBase):
    """
        Crops out PSFs at the positions given.
        Built-in filter by flattened index.
        Built-in filter for removing multiply-peaked data.
        Filters work on flattened X, Y images.
        Results are stacked along the 4th dimension.
    """
    
    inputName = Input('input')
    input_pos = Input('psf_pos')
    
    ignore_pos = List(Int, [])
    threshold_reject = Float(0.5)
    com_reject = Float(2.0)
    
    half_roi_x = Int(20)
    half_roi_y = Int(20)
    half_roi_z = Int(60)
    
    output_images = Output('psf_cropped')
    
    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_pos = namespace[self.input_pos]
        
        res = np.zeros((self.half_roi_x*2+1, self.half_roi_y*2+1, self.half_roi_z*2+1, sum([ar.shape[0] for ar in psf_pos])))
        
        mask = np.ones(res.shape[3], dtype=bool)
        counter = 0
        for c in np.arange(ims.data.shape[3]):
            for i in np.arange(len(psf_pos[c])):
#                print psf_pos[c][i][:3]
                x, y, z = psf_pos[c][i][:3]
                x_slice = slice(x-self.half_roi_x, x+self.half_roi_x+1)
                y_slice = slice(y-self.half_roi_y, y+self.half_roi_y+1)
                z_slice = slice(z-self.half_roi_z, z+self.half_roi_z+1)
                res[:, :, :, counter] = ims.data[x_slice, y_slice, z_slice, c].squeeze()
                
                crop_flatten = res[:, :, :, counter].mean(2)
                failed = False
                labeled_image, labeled_counts = ndimage.label(crop_flatten > crop_flatten.max() * self.threshold_reject)
                if labeled_counts > 1:
                    failed = True
                else:
                    com = np.asarray(ndimage.center_of_mass(crop_flatten, labeled_image, 1))
                    img_center = np.asarray([(s-1)*0.5 for s in labeled_image.shape])
                    dist = np.linalg.norm(com - img_center)
#                    print(com, img_center, dist)
                    if dist > self.com_reject:
                        failed = True
                    
                if failed and counter not in self.ignore_pos:
                    self.ignore_pos.append(counter)                    
                
                counter += 1

        # To do: add metadata
#        mdh['ImageType=']='PSF'        
        print "images ignore: {}".format(self.ignore_pos)
        mask[self.ignore_pos] = False
        
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["ImageType"] = 'PSF'
            if not "PSFExtraction.SourceFilenames" in new_mdh.keys():
                new_mdh["PSFExtraction.SourceFilenames"] = ims.filename
        except Exception as e:
            print(e)
            
        namespace[self.output_images] = ImageStack(data=res[:,:,:,mask], mdh=new_mdh)
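

# Sketch of the rejection test used in CropPSF.execute above: threshold the
# z-projected crop, label connected blobs, and reject when there is more than
# one blob or its centre of mass sits too far from the image centre (toy
# helper mirroring the same ndimage calls).
def _demo_crop_reject(crop_flatten, threshold=0.5, com_tol=2.0):
    labeled_image, labeled_counts = ndimage.label(
        crop_flatten > crop_flatten.max() * threshold)
    if labeled_counts > 1:
        return True  # multiple peaks -> reject
    com = np.asarray(ndimage.center_of_mass(crop_flatten, labeled_image, 1))
    img_center = np.asarray([(s - 1) * 0.5 for s in crop_flatten.shape])
    return np.linalg.norm(com - img_center) > com_tol
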
Example #11
class TrackRenderLayer(EngineLayer):
    """
    A layer for viewing tracking data

    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Transparency')
    line_width = Float(1.0, desc='Track line width')
    method = Enum(*ENGINES.keys(), desc='Method used to display tracks')
    clump_key = CStr('clumpIndex',
                     desc="Name of column containing the track identifier")
    dsname = CStr('output',
                  desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='tracks',
                 dsname='',
                 context=None,
                 **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, clump_key')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        #logger.debug('Setting dsname and method')
        self.dsname = dsname
        self.method = method

        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        #logger.debug('Setting layer method to %s' % self.method)
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        try:
            if isinstance(self.datasource, ClumpManager):
                cdata = []
                for track in self.datasource.all:
                    cdata.extend(track[self.vertexColour])
                cdata = np.array(cdata)
            else:
                # Assume tabular dataset
                cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]
        #self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        self._datasource_choices = self._pipeline.layer_data_source_names
        if self.datasource is not None:
            if isinstance(self.datasource, ClumpManager):
                # Grab the keys from the first Track in the ClumpManager
                self._datasource_keys = sorted(self.datasource[0].keys())
            else:
                # Assume we have a tabular data source
                self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        if isinstance(ds, ClumpManager):
            x = []
            y = []
            z = []
            c = []
            self.clumpSizes = []

            # Copy data from tracks. This is already in clump order
            # thanks to ClumpManager
            for track in ds.all:
                x.extend(track['x'])
                y.extend(track['y'])
                z.extend(track['z'])
                self.clumpSizes.append(track.nEvents)

                if self.vertexColour != '':
                    c.extend(track[self.vertexColour])
                else:
                    c.extend([0 for i in track['x']])

            x = np.array(x)
            y = np.array(y)
            z = np.array(z)
            c = np.array(c)


        else:
            # Assume tabular data source
            x, y = ds[self.x_key], ds[self.y_key]

            if self.z_key is not None:
                try:
                    z = ds[self.z_key]
                except KeyError:
                    z = 0 * x
            else:
                z = 0 * x

            if self.vertexColour != '':
                c = ds[self.vertexColour]
            else:
                c = 0 * x

            # Work out clump start and finish indices
            # TODO - optimize / precompute????
            ci = ds[self.clump_key]

            NClumps = int(ci.max())

            clist = [[] for i in range(NClumps)]
            for i, cl_i in enumerate(ci):
                clist[int(cl_i - 1)].append(i)

            # This and self.clumpStarts are class attributes for
            # compatibility with the old Track rendering layer,
            # PYME.LMVis.gl_render3D.TrackLayer
            self.clumpSizes = [len(cl_i) for cl_i in clist]

            #reorder x, y, z, c in clump order
            I = np.hstack([np.array(cl) for cl in clist]).astype(int)

            x = x[I]
            y = y[I]
            z = z[I]
            c = c[I]

        self.clumpStarts = np.cumsum([0] + self.clumpSizes)

        #do normal vertex stuff
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

        self._vertices = vertices
        self._normals = -0.69 * np.ones(vertices.shape)
        self._bbox = np.array([x.min(), y.min(), z.min(),
                               x.max(), y.max(), z.max()])

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        if clim is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = float(self.alpha)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if vertices is not None:
                self._colors = np.ones((vertices.shape[0], 4), 'f')

        self._color_map = cmap
        self._color_limit = clim
        self._alpha = float(self.alpha)

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)), Item('line_width'))
        ])
        #buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
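# A small self-contained sketch of the clump bookkeeping in
# update_from_datasource: given a 1-based per-point track identifier (the
# `clump_key` column), points are regrouped so each track's vertices are
# contiguous, after which clump_starts[i]:clump_starts[i+1] indexes track i.
# The `ci` array below is hypothetical data.
import numpy as np

ci = np.array([1, 2, 1, 3, 2, 1])  # clumpIndex per point
clist = [[] for _ in range(int(ci.max()))]
for i, cl_i in enumerate(ci):
    clist[int(cl_i - 1)].append(i)

clump_sizes = [len(cl) for cl in clist]                        # [3, 2, 1]
order = np.hstack([np.array(cl) for cl in clist]).astype(int)  # [0 2 5 1 4 3]
clump_starts = np.cumsum([0] + clump_sizes)                    # [0 3 5 6]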
Example #12
0
class LoadDriftandInterp(ModuleBase):
    """
    Loads drift data from file(s) and uses them to create spline interpolators (``scipy.interpolate.UnivariateSpline``), one per axis.
        
    Inputs
    ------
    input_dummy : None
       Blank input. Required to run correctly.
    
    Outputs
    -------
    output_drift_interpolator : list of callable
        Drift interpolators, one per axis; each returns the drift when called with a frame number / time.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.
        
    Parameters
    ----------
    load_paths : list of File
        List of files to load.
    degree_of_spline : int
        Degree of the smoothing spline.
    smoothing_factor : float
        Smoothing factor.
    """

    input_dummy = Input('input')  # dummy input; the GUI breaks without it
    #    load_path = File()
    load_paths = List(File, [""], 1)
    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(-1)  # 0 for no smoothing; negative to use the UnivariateSpline default
    #    input_drift_raw = Input('drift_raw')
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    #    output_drift_raw= Input('drift_raw')

    def execute(self, namespace):
        spl_array = list()
        t_min = np.inf
        t_max = 0
        tIndexes = list()
        drifts = list()
        for fil in self.load_paths:
            data = np.load(fil)
            tIndex = data['tIndex']
            t_min = min(t_min, tIndex[0])
            t_max = max(t_max, tIndex[-1])
            drift = data['drift']

            tIndexes.append(tIndex)
            drifts.append(drift)

            spl = interpolate_drift(tIndex, drift, self.degree_of_spline,
                                    self.smoothing_factor)
            spl_array.append(spl)


        spl_array = list(zip(*spl_array))

        def spl_method(funcs, t):
            return np.sum([f(t) for f in funcs], axis=0)

        spl_combined = list()
        for spl in spl_array:
            # partial avoids the late-binding pitfall of a lambda closing over spl
            spl_combined.append(partial(spl_method, spl))

        namespace[self.output_drift_interpolator] = spl_combined

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(
            partial(generate_drift_plot, tIndexes, drifts, spl_combined))
        namespace[self.output_drift_plot].plot()
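# A minimal sketch of the per-axis spline combination above, assuming each
# .npz file holds a 'tIndex' vector and an (N, n_axes) 'drift' array; the
# names and synthetic data are illustrative only, and `fit_axis_splines`
# stands in for the module's `interpolate_drift` helper (assumed behaviour).
# One spline is fitted per axis per file, and the interpolators for the same
# axis are summed across files.
import numpy as np
from functools import partial
from scipy.interpolate import UnivariateSpline

def fit_axis_splines(t, drift, k=3, s=None):
    # one smoothing spline per spatial axis
    return [UnivariateSpline(t, drift[:, d], k=k, s=s) for d in range(drift.shape[1])]

t = np.arange(100.)
drift = np.cumsum(0.1 * np.random.randn(100, 2), axis=0)  # synthetic x/y drift
per_file = [fit_axis_splines(t, drift)]                   # a single "file" here

def summed(funcs, t):
    return np.sum([f(t) for f in funcs], axis=0)

interpolators = [partial(summed, axis_splines) for axis_splines in zip(*per_file)]
x_drift_at_frame_42 = interpolators[0](42.0)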