class PSFSettings(HasTraits):
    """Traits container holding the settings used to describe a PSF model.

    The 4Pi-specific fields (``zernike_modes_lower`` and ``phases``) are only
    shown in the default view when ``four_pi`` is enabled.
    """
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])
    four_pi = Bool(False)

    def default_traits_view(self):
        """Return the traitsui View used to edit these settings."""
        from traitsui.api import View, Item

        shown_for_4pi = 'four_pi==True'
        items = [
            Item(name='wavelength_nm'),
            Item(name='NA'),
            Item(name='vectorial'),
            Item(name='four_pi', label='4Pi'),
            Item(name='zernike_modes'),
            Item(name='zernike_modes_lower', visible_when=shown_for_4pi),
            Item(name='phases', visible_when=shown_for_4pi, label='phases/pi'),
        ]
        return View(*items, resizable=True, buttons=['OK'])
class CombineBeadStacks(ModuleBase):
    """
    Combine multiple bead stacks in the 4th dimension.

    X, Y, Z must be identical across all files.

    Parameters
    ----------
    files : list of File
        Image files to concatenate, in order, along the 4th axis.
    cache : File
        Optional path for a disk-backed (memmap) output array. When empty,
        the combined stack is held in memory.

    Outputs
    -------
    outputName : ImageStack
        The combined stack. Its metadata is copied from the first file and
        records the source filenames under ``PSFExtraction.SourceFilenames``.
    """
    # NOTE(review): not read by execute(); presumably only present so the module has an input in the recipe graph
    inputName = Input('dummy')
    files = List(File, ['', ''], 2)
    cache = File()
    outputName = Output('bead_images')

    def execute(self, namespace):
        # The first file defines the X/Y/Z shape, the dtype and the metadata of the result.
        ims = ImageStack(filename=self.files[0])
        # ``np.long`` was an alias of the builtin ``int`` and was removed in NumPy 1.24.
        dims = np.asarray(ims.data.shape, dtype=int)
        dims[3] = 0
        dtype_ = ims.data[:, 0, 0, 0].dtype
        mdh = ims.mdh
        del ims

        # First pass: total extent of the 4th dimension over all files.
        for fil in self.files:
            ims = ImageStack(filename=fil)
            dims[3] += ims.data.shape[3]
            del ims

        shape = tuple(int(d) for d in dims)
        if self.cache != '':
            raw_data = np.memmap(self.cache, dtype=dtype_, mode='w+', shape=shape)
        else:
            raw_data = np.zeros(shape=shape, dtype=dtype_)

        # Second pass: copy each file into its slot along the 4th axis.
        counter = 0
        for fil in self.files:
            ims = ImageStack(filename=fil)
            c_len = ims.data.shape[3]
            data = ims.data[:, :, :, :]
            # pad with singleton axes so the data always has at least 4 dims
            data.shape += (1,) * (4 - data.ndim)
            raw_data[:, :, :, counter:counter + c_len] = data
            counter += c_len
            del ims

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(mdh)
            new_mdh["PSFExtraction.SourceFilenames"] = self.files
        except Exception as e:
            print(e)

        namespace[self.outputName] = ImageStack(data=raw_data, mdh=new_mdh)
class Binning(CacheCleanupModule):
    """
    Downsample 3D data (mean).

    X, Y pixels that don't fill a full bin are dropped. Pixels in the 3rd
    dimension can have a partially filled bin.

    Inputs
    ------
    inputName : ImageStack

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    x_start : int
        Starting index in x.
    x_end : int
        Stopping index in x. It indexes into ``range(size_x + 1)``, so the
        default of -1 keeps every pixel up to and including the last one.
    y_start : int
        Starting index in y.
    y_end : int
        Stopping index in y (same convention as ``x_end``).
    binsize : list of int
        Bin size in x, y and z.
    cache_bin : File
        File used as a disk-backed (memmap) cache for the binned output. If
        empty, the binned output is held in memory instead.
    """
    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    # z_start = Int(0)
    # z_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.inputName]

        # ``np.int`` / ``np.long`` were aliases of the builtin ``int`` and have
        # been removed from NumPy (1.24+), so use ``int`` directly.
        binsize = np.asarray(self.binsize, dtype=int)

        # unconventional: slicing arange(n + 1) means an end index of -1 keeps
        # all n pixels, i.e. the end stop is effectively inclusive
        x_slice = np.arange(ims.data.shape[0] + 1)[slice(self.x_start, self.x_end, 1)]
        y_slice = np.arange(ims.data.shape[1] + 1)[slice(self.y_start, self.y_end, 1)]

        # drop trailing x/y pixels which don't fill a complete bin
        x_slice = x_slice[:x_slice.shape[0] // binsize[0] * binsize[0]]
        y_slice = y_slice[:y_slice.shape[0] // binsize[1] * binsize[1]]

        # z uses ceiling division so that a final, partially filled bin is kept
        bincounts = np.asarray([len(x_slice) // binsize[0],
                                len(y_slice) // binsize[1],
                                -(-ims.data.shape[2] // binsize[2])], dtype=int)

        x_slice_ind = slice(x_slice[0], x_slice[-1] + 1)
        y_slice_ind = slice(y_slice[0], y_slice[-1] + 1)

        # interleave counts and sizes: (nbx, bx, nby, by, nbz, bz)
        new_shape = np.stack([bincounts, binsize], -1).flatten()

        # need to wrap this to work for multiply color channel images
        dtype = ims.data[:, :, 0].dtype

        out_shape = tuple(int(b) for b in bincounts)
        if self.cache_bin != '':
            binned_image = np.memmap(self.cache_bin, dtype=dtype, mode='w+', shape=out_shape)
        else:
            binned_image = np.zeros(out_shape, dtype=dtype)

        # each chunk covers one z bin; the z extent is left as -1 so that a
        # partially filled final bin still reshapes correctly
        new_shape_one_chunk = new_shape.copy()
        new_shape_one_chunk[4] = 1
        new_shape_one_chunk[5] = -1

        # report progress (and flush the cache) roughly every 20% of the z range
        progress = 0.2 * ims.data.shape[2]
        for i, f in enumerate(np.arange(0, ims.data.shape[2], binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind, f:f + binsize[2]].squeeze()

            binned_image[:, :, i] = raw_data_chunk.reshape(new_shape_one_chunk).mean((1, 3, 5)).squeeze()

            if f + binsize[2] >= progress:
                if isinstance(binned_image, np.memmap):
                    binned_image.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed binning {} of {} total images.".format(
                    time.time() - self._start_time,
                    min(f + binsize[2], ims.data.shape[2]),
                    ims.data.shape[2]))

        im = ImageStack(binned_image, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename

        try:
            ### Metadata must be logged correctly for the measured drift to be applicable to the source image
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            # im.mdh['voxelsize.z'] *= binsize[2]
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except Exception:
            # best effort - presumably the source metadata lacks a voxel size
            pass

        namespace[self.outputName] = im
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, using one of the engines in ``ENGINES``.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')

    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='points', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        if self.method == method:
            # assigning an unchanged value does not fire the trait handler, so call it explicitly
            self._set_method()
        else:
            # fires _set_method() via on_trait_change; avoids building the engine twice
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Create a new rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the colour-variable data, or ``[0, 1]`` if it is unavailable.

        ``TypeError`` covers the case where there is no datasource (None), in
        line with the other render layers.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the colour variable.

        The ``clim`` assignment itself triggers ``update()`` through ``on_trait_change``.
        """
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Refresh the datasource choices/keys and push the data to the engine."""
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """Extract positions, optional normals and colour data from ``ds`` and pass them to ``update_data``."""
        x, y = ds[self.x_key], ds[self.y_key]

        if not self.z_key is None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha,
                             xn=xn, yn=yn, zn=zn)
        else:
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha)

    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0, xn=None, yn=None,
                    zn=None):
        """Build the (N, 3) vertex/normal arrays and (N, 4) RGBA colours and store them.

        Colours are only mapped through ``cmap`` when ``colors``, ``clim`` and
        ``cmap`` are all given; otherwise every vertex is drawn opaque white.
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # previously this tested ``clim`` twice and never ``cmap``, so cmap=None crashed
        if cmap is not None and colors is not None and clim is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        else:
            if not vertices is None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self, vertices=None, normals=None, colors=None, color_map=None, color_limit=None, alpha=None):
        """Store each non-None argument on the layer; None arguments leave the current value untouched."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour',
                          visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                                 show_label=False), ],
                           visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group(Item('cmap', label='LUT'),
                           Item('alpha', visible_when="method in ['pointsprites', 'transparent_points']",
                                editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                           Item('point_size', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)))
                     ])
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class ModuleCollection(HasTraits):
    """An ordered collection of recipe modules plus the namespace they read from and write to."""
    modules = List()
    execute_on_invalidation = Bool(False)

    def __init__(self, *args, **kwargs):
        HasTraits.__init__(self, *args, **kwargs)

        self.namespace = {}

        # we open hdf files and don't necessarily read their contents into memory - these need to be closed when we
        # either delete the recipe, or clear the namespace
        self._open_input_files = []

        self.recipe_changed = dispatch.Signal()
        self.recipe_executed = dispatch.Signal()

    def invalidate_data(self):
        """Re-run the recipe if ``execute_on_invalidation`` is set."""
        if self.execute_on_invalidation:
            self.execute()

    def clear(self):
        self.namespace.clear()

    def new_output_name(self, stub):
        """Return a namespace key based on ``stub`` that is not already in use.

        Returns ``stub`` itself if no existing key starts with it, otherwise
        ``'<stub>_<n>'``. Here ``n`` starts at the number of keys sharing the
        prefix and is bumped until the name is free.
        """
        # previously every key was counted, not only those starting with ``stub``
        count = sum(1 for k in self.namespace.keys() if k.startswith(stub))
        if count == 0:
            return stub

        name = '%s_%d' % (stub, count)
        while name in self.namespace:
            count += 1
            name = '%s_%d' % (stub, count)
        return name

    def dependancyGraph(self):
        """Map each module to the names it consumes, and each output name to ``{producing module}``."""
        dg = {}

        #only add items to dependancy graph if they are not already in the namespace
        #calculated_objects = namespace.keys()

        for mod in self.modules:
            s = mod.inputs

            try:
                s.update(dg[mod])
            except KeyError:
                pass

            dg[mod] = s

            for op in mod.outputs:
                #if not op in calculated_objects:
                dg[op] = {mod, }

        return dg

    def reverseDependancyGraph(self):
        """Invert ``dependancyGraph``: map each item to the set of items that depend on it."""
        dg = self.dependancyGraph()

        rdg = {}

        for k, vs in dg.items():
            for v in vs:
                rdg.setdefault(v, set()).add(k)

        return rdg

    def _getAllDownstream(self, rdg, keys):
        """get all the downstream items which depend on the given key"""

        downstream = set()

        next_level = set()
        for k in keys:
            try:
                next_level.update(rdg[k])
            except KeyError:
                pass

        if next_level:
            downstream.update(next_level)

            downstream.update(self._getAllDownstream(rdg, list(next_level)))

        return downstream

    def prune_dependencies_from_namespace(self, keys_to_prune, keep_passed_keys=False):
        """Remove everything downstream of ``keys_to_prune`` (and the keys themselves unless ``keep_passed_keys``)."""
        rdg = self.reverseDependancyGraph()

        if keep_passed_keys:
            downstream = list(self._getAllDownstream(rdg, list(keys_to_prune)))
        else:
            downstream = list(keys_to_prune) + list(self._getAllDownstream(rdg, list(keys_to_prune)))

        for dsi in downstream:
            try:
                self.namespace.pop(dsi)
            except KeyError:
                #the output is not in our namespace, no need to prune
                pass
            except AttributeError:
                #we might not have our namespace defined yet
                pass

    def resolveDependencies(self):
        """Return modules and data names in a topologically sorted execution order."""
        import toposort
        #build dependancy graph

        dg = self.dependancyGraph()

        #solve the dependency tree
        return toposort.toposort_flatten(dg, sort=False)

    def execute(self, **kwargs):
        """Execute the recipe.

        Keyword arguments are injected into the namespace. If a value differs
        from what was already stored under that key, everything downstream of
        it is pruned first so that it is recomputed. Modules whose outputs are
        already present in the namespace are skipped.

        Returns
        -------
        ``namespace['output']`` if present, otherwise None.
        """
        #remove anything which is downstream from changed inputs
        for k, v in kwargs.items():
            try:
                if not (self.namespace[k] == v):
                    #input has changed
                    print('pruning: ', k)
                    self.prune_dependencies_from_namespace([k])
            except KeyError:
                #key wasn't in namespace previously
                print('KeyError')
                pass

        self.namespace.update(kwargs)

        exec_order = self.resolveDependencies()

        for m in exec_order:
            if isinstance(m, ModuleBase) and not m.outputs_in_namespace(self.namespace):
                try:
                    m.execute(self.namespace)
                except:
                    logger.exception("Error in recipe module: %s" % m)
                    raise

        self.recipe_executed.send_robust(self)

        if 'output' in self.namespace.keys():
            return self.namespace['output']

    @classmethod
    def fromMD(cls, md):
        """Build a recipe from metadata whose keys are ``'<ModuleName>.<param>'``."""
        c = cls()

        moduleNames = set([s.split('.')[0] for s in md.keys()])

        mc = []

        for mn in moduleNames:
            mod = all_modules[mn]()
            mod.set(**md[mn])
            mc.append(mod)

        #return cls(modules=mc)
        c.modules = mc

        return c

    def get_cleaned_module_list(self):
        """Return ``[{module_name: {param: value}}, ...]`` with private and default-valued params removed.

        ``input*`` / ``output*`` params are always kept, and traits container
        subclasses are converted to plain dict/list/set so they can be serialized.
        """
        l = []
        for mod in self.modules:
            ct = mod.class_traits()

            mod_traits_cleaned = {}
            for k, v in mod.get().items():
                if not k.startswith('_'):  #don't save private data - this is usually used for caching etc ..,
                    try:
                        if (not (v == ct[k].default)) or (k.startswith('input')) or (k.startswith('output')):
                            #don't save defaults
                            if isinstance(v, dict) and not type(v) == dict:
                                v = dict(v)
                            elif isinstance(v, list) and not type(v) == list:
                                v = list(v)
                            elif isinstance(v, set) and not type(v) == set:
                                v = set(v)

                            mod_traits_cleaned[k] = v
                    except KeyError:
                        # for some reason we have a trait that shouldn't be here
                        pass

            l.append({module_names[mod.__class__]: mod_traits_cleaned})

        return l

    def toYAML(self):
        """Serialize the recipe to YAML (block style, safe dumper)."""
        import yaml

        class MyDumper(yaml.SafeDumper):
            def represent_mapping(self, tag, value, flow_style=None):
                return super(MyDumper, self).represent_mapping(tag, value, False)

        return yaml.dump(self.get_cleaned_module_list(), Dumper=MyDumper)

    def toJSON(self):
        import json
        return json.dumps(self.get_cleaned_module_list())

    def _update_from_module_list(self, l):
        """
        Update from a parsed yaml or json list of modules

        It probably makes no sense to call this directly as the format is pretty wack - a list of dictionaries, each
        with a single entry, but that is how the yaml parses.

        Parameters
        ----------
        l: list
            List of modules as obtained from parsing a yaml recipe. Each module is a dictionary mapping a single
            module name to its parameters, e.g.

            [{'Filtering.Filter': {'filters': {'probe': [-0.5, 0.5]}, 'input': 'localizations', 'output': 'filtered'}}]

        Returns
        -------
        None
        """
        mc = []

        if l is None:
            l = []

        for mdd in l:
            mn, md = list(mdd.items())[0]
            try:
                mod = all_modules[mn](self)
            except KeyError:
                # still support loading old recipes which do not use hierarchical names
                # also try and support modules which might have moved
                mod = _legacy_modules[mn.split('.')[-1]](self)

            mod.set(**md)
            mc.append(mod)

        self.modules = mc

        self.recipe_changed.send_robust(self)
        self.invalidate_data()

    @classmethod
    def _from_module_list(cls, l):
        """ A factory method which contains the common logic for loading/creating from either
        yaml or json. Do not call directly"""
        c = cls()
        c._update_from_module_list(l)

        return c

    @classmethod
    def fromYAML(cls, data):
        import yaml

        # safe_load: recipes are written with a SafeDumper, and a bare yaml.load without a
        # Loader is both unsafe and a TypeError on PyYAML >= 6
        l = yaml.safe_load(data)
        return cls._from_module_list(l)

    def update_from_yaml(self, data):
        """
        Update from a yaml formatted recipe description

        Parameters
        ----------
        data: str
            either yaml formatted text, or the path to a yaml file.

        Returns
        -------
        None
        """
        import os
        import yaml

        if os.path.isfile(data):
            with open(data) as f:
                data = f.read()

        # see fromYAML for why safe_load is used
        l = yaml.safe_load(data)
        return self._update_from_module_list(l)

    @classmethod
    def fromJSON(cls, data):
        import json
        return cls._from_module_list(json.loads(data))

    def add_module(self, module):
        self.modules.append(module)
        self.recipe_changed.send_robust(self)

    @property
    def inputs(self):
        """Names consumed by any module which start with 'in'."""
        ip = set()
        for mod in self.modules:
            ip.update({k for k in mod.inputs if k.startswith('in')})
        return ip

    @property
    def outputs(self):
        """Names produced by any module which start with 'out'."""
        op = set()
        for mod in self.modules:
            op.update({k for k in mod.outputs if k.startswith('out')})
        return op

    @property
    def module_outputs(self):
        """All names produced by any module."""
        op = set()
        for mod in self.modules:
            op.update(set(mod.outputs))
        return op

    @property
    def file_inputs(self):
        out = []
        for mod in self.modules:
            out += mod.file_inputs

        return out

    def save(self, context=None):
        """
        Find all OutputModule instances and call their save methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to a new empty dict.
        """
        if context is None:
            context = {}

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                mod.save(self.namespace, context)

    def gather_outputs(self, context=None):
        """
        Find all OutputModule instances and call their generate methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to a new empty dict.

        Returns
        -------
        list
            The non-None results of each OutputModule's ``generate``.
        """
        if context is None:
            context = {}

        outputs = []

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                out = mod.generate(self.namespace, context)

                if not out is None:
                    outputs.append(out)

        return outputs

    def loadInput(self, filename, key='input'):
        """Load input data from a file and inject into namespace

        HDF files (.h5r, .h5, .hdf) have their VLArrays, tables and EArray images loaded. Anything else except
        .csv / .xls(x), which raise NotImplementedError, is opened as an ImageStack.
        """
        #modify this to allow for different file types - currently only supports images
        from PYME.IO import unifiedIO
        import os
        extension = os.path.splitext(filename)[1]
        if extension in ['.h5r', '.h5', '.hdf']:
            import tables
            from PYME.IO import MetaDataHandler
            from PYME.IO import tabular

            with unifiedIO.local_or_temp_filename(filename) as fn:
                with tables.open_file(fn, mode='r') as h5f:
                    #make sure our hdf file gets closed

                    key_prefix = '' if key == 'input' else key + '_'

                    try:
                        mdh = MetaDataHandler.NestedClassMDHandler(MetaDataHandler.HDFMDHandler(h5f))
                    except tables.FileModeError:  # Occurs if no metadata is found, since we opened the table in read-mode
                        logger.warning('No metadata found, proceeding with empty metadata')
                        mdh = MetaDataHandler.NestedClassMDHandler()

                    for t in h5f.list_nodes('/'):
                        # FIXME - The following isinstance tests are not very safe (and badly broken in some cases e.g.
                        # PZF formatted image data, Image data which is not in an EArray, etc ...)
                        # Note that EArray is only used for streaming data!
                        # They should ideally be replaced with more comprehensive tests (potentially based on array or dataset
                        # dimensionality and/or data type) - i.e. duck typing. Our strategy for images in HDF should probably
                        # also be improved / clarified - can we use hdf attributes to hint at the data intent? How do we support
                        # > 3D data?

                        if isinstance(t, tables.VLArray):
                            from PYME.IO.ragged import RaggedVLArray

                            rag = RaggedVLArray(h5f, t.name, copy=True)  #force an in-memory copy so we can close the hdf file properly
                            rag.mdh = mdh

                            self.namespace[key_prefix + t.name] = rag

                        elif isinstance(t, tables.table.Table):
                            #  pipe our table into h5r or hdf source depending on the extension
                            tab = tabular.H5RSource(h5f, t.name) if extension == '.h5r' else tabular.HDFSource(h5f, t.name)
                            tab.mdh = mdh

                            self.namespace[key_prefix + t.name] = tab

                        elif isinstance(t, tables.EArray):
                            # load using ImageStack._loadh5, which finds metdata
                            im = ImageStack(filename=filename, haveGUI=False)
                            # assume image is the main table in the file and give it the named key
                            self.namespace[key] = im

        elif extension == '.csv':
            logger.error('loading .csv not supported yet')
            raise NotImplementedError
        elif extension in ['.xls', '.xlsx']:
            logger.error('loading .xls not supported yet')
            raise NotImplementedError
        else:
            self.namespace[key] = ImageStack(filename=filename, haveGUI=False)

    @property
    def pipeline_view(self):
        """traitsui View listing every module, or None when no wx App is running."""
        import wx
        if wx.GetApp() is None:
            return None
        else:
            from traitsui.api import View, ListEditor, InstanceEditor, Item
            return View(Item('modules',
                             editor=ListEditor(style='custom', editor=InstanceEditor(view='pipeline_view'),
                                               mutable=False),
                             style='custom', show_label=False),
                        buttons=['OK', 'Cancel'])

    def to_svg(self):
        from . import recipeLayout
        return recipeLayout.to_svg(self.dependancyGraph())

    def _repr_svg_(self):
        """ Make us look pretty in Jupyter"""
        return self.to_svg()
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    Displays one (slice, channel) plane of an ImageStack datasource, with
    colour limits / colour map taken from the layer's traits.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    # index into the 4th axis of ds.data
    channel = Int(0)
    # index into the 3rd axis of ds.data
    slice = Int(0)
    # NOTE(review): not read anywhere in this class
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        # a dh5view display_options instance - if provided, sync_to_display_opts() uses it to over-ride the
        # clim, cmap and visibility properties
        self._do = display_opts

        # (dsname, slice, channel) of the plane currently cached in self._im
        self._im_key = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        # NOTE(review): changes to `slice` / `channel` do not trigger update() here - confirm callers refresh explicitly
        #self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (not self._pipeline is None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to.

        Uses the pipeline's ``get_layer_data``; if the pipeline has no such
        attribute it is treated as a plain dict keyed by datasource name.
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            #fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]
        #return self.datasource

    @property
    def _ds_class(self):
        # Only datasources of this type are offered in _datasource_choices.
        # from PYME.experimental import triangle_mesh
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        """Create a new rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    # def _update(self, *args, **kwargs):
    #     #pass
    #     cdata = self._get_cdata()
    #     self.clim = [float(cdata.min()), float(cdata.max())]
    #     self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """Refresh the list of ImageStack datasources and, if possible, the displayed image."""
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        except AttributeError:
            # pipeline is a plain dict
            self._datasource_choices = [k for k, v in self._pipeline.items() if isinstance(v, self._ds_class)]

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """Copy clim, cmap and visibility for ``self.channel`` from a display-options object.

        Uses ``do`` if given, else the ``display_opts`` passed to the
        constructor; does nothing if neither is available.
        """
        if (do is None):
            if not (self._do is None):
                do = self._do
            else:
                return

        # clim spans [offset, offset + 1/gain]
        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Cache the current image plane and the display settings for the engine.

        The plane ``ds.data[:, :, slice, channel]`` is only re-extracted (as
        float32) when (dsname, slice, channel) changes. NOTE(review): changes to
        the data under the same key are therefore not picked up.

        Parameters
        ----------

        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------

        None
        """

        #if self._do is not None:
        # Let display options (if provied) over-ride our settings (TODO - is this the right way to do this?)
        #    o = self._do.Offs[self.channel]
        #    g = self._do.Gains[self.channel]
        #    clim = [o, o + 1.0/g]
        #self.clim = clim

        #    cmap = self._do.cmaps[self.channel]
        #self.visible = self._do.show[self.channel]
        #else:

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        alpha = float(self.alpha)

        # NOTE(review): c0, c1 are unused (the scaling below is commented out)
        c0, c1 = clim

        im_key = (self.dsname, self.slice, self.channel)

        if not self._im_key == im_key:
            self._im_key = im_key
            self._im = ds.data[:,:,self.slice, self.channel].astype('f4')# - c0)/(c1-c0)

            # only the x/y extent of ds.imgBounds is used; z is fixed at 0
            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])

            self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        # every 20th pixel of the cached plane, used for the clim histogram
        # NOTE(review): self._im only exists after update_from_datasource() has run
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     #Item('method'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.

    Each face is expanded into three vertices, with per-vertex or per-face
    normals (``normal_mode``) and optional colouring by a mesh variable
    (``vertexColour``).
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        # NOTE(review): these keys are not read by update_from_datasource(), which uses the mesh attributes directly
        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to, looked up via the pipeline's ``get_layer_data(dsname)``.
        """
        return self._pipeline.get_layer_data(self.dsname)
        #return self.datasource

    @property
    def _ds_class(self):
        # Only datasources of this type are offered in _datasource_choices.
        # from PYME.experimental import triangle_mesh
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """Create a new rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the colour-variable data, or ``[0, 1]`` if missing (KeyError) or there is no datasource (TypeError)."""
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the newly selected colour variable, then refresh."""
        #pass
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """Refresh the mesh datasource choices, the colour-variable keys, and the engine data."""
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]

        if not self.datasource is None:
            # 'constant' is always offered as a colour option
            dks = ['constant',]
            if hasattr(self.datasource, 'keys'):
                dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build vertex, normal and colour arrays from a triangle-mesh datasource.

        Every face contributes three vertices taken from ``ds.vertices[ds.faces]``.
        Normals come from ``ds.vertex_normals`` ('Per vertex') or are
        ``ds.face_normals`` repeated three times ('Per face'). Colours come from
        ``ds[vertexColour]`` indexed by face, or are constant.

        Parameters
        ----------
        ds :
            PYME.experimental.triangular_mesh.TriangularMesh object

        Returns
        -------
        None
        """
        #t = ds.vertices[ds.faces]
        #n = ds.vertex_normals[ds.faces]

        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        #elif self.vertexColour == 'vertex_index':
        #    c = np.arange(0, len(x))
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        # TODO: This temporarily sets all triangles to the color red. User should be able to select color.
        # NOTE(review): c is always assigned an array above, so this branch looks unreachable
        if c is None:
            c = np.ones(self._vertices.shape[0]) * 255  # vector of pink

        if clim is not None and c is not None and cmap is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)

            # for 'flat' / 'tessel' the opacity scales with the normalized colour value
            if self.method in ['flat', 'tessel']:
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            # cs = None
            if not self._vertices is None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class InterpolatePSF(ModuleBase):
    """
    Interpolate PSF with RBF.
    Very stupid. Very slow. Performed on local pixels and combine by tiling.
    Only uses the first color channel.

    For every voxel of the output grid, a cubic ``scipy.interpolate.Rbf`` is fitted to the
    input samples lying strictly within ``rbf_radius`` of that voxel and evaluated there.

    Inputs
    ------
    inputName : ImageStack
        PSF stack; only index 0 of the 4th dimension is used.

    Outputs
    -------
    output_images : ImageStack
        PSF resampled onto a grid with spacing ``target_voxelsize``, covering the same
        extent as the input.

    Parameters
    ----------
    rbf_radius : float
        Radius of the neighbourhood used for each local RBF fit, in the same units as
        ``target_voxelsize``.
    target_voxelsize : list of float
        Output voxel size along x, y, z. Coordinates are built as metadata voxel size x 1e3,
        so these are nm assuming the metadata voxel sizes are in um (TODO confirm).
    """
    inputName = Input('input')
    rbf_radius = Float(250.0)
    target_voxelsize = List(Float, [100., 100., 100.])

    output_images = Output('psf_interpolated')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        data = ims.data[:,:,:,0]

        voxelsize = [ims.mdh.voxelsize.x, ims.mdh.voxelsize.y, ims.mdh.voxelsize.z]

        # centred coordinate axes of the input grid
        dims_original = list()
        for dim, dim_len in enumerate(data.shape):
            d = np.linspace(0, dim_len-1, dim_len) * voxelsize[dim] * 1E3
            d -= d.mean()
            dims_original.append(d)
        X, Y, Z = np.meshgrid(*dims_original, indexing='ij')

        # centred coordinate axes of the output grid, spanning the same physical extent
        dims_interpolated = list()
        for dim, dim_len in enumerate(data.shape):
            tar_len = int(np.ceil((voxelsize[dim]*1E3 * dim_len) / self.target_voxelsize[dim]))
            d = np.arange(tar_len) * self.target_voxelsize[dim]
            d -= d.mean()
            dims_interpolated.append(d)

        X_interp, Y_interp, Z_interp = np.meshgrid(*dims_interpolated, indexing='ij')

        pts_interp = zip(X_interp.flatten(), Y_interp.flatten(), Z_interp.flatten())

        results = np.zeros(X_interp.size)
        for i, pt in enumerate(pts_interp):
            # pass the configured radius; previously the trait was ignored and 250 was always used
            results[i] = self.InterpolateAt(pt, X, Y, Z, data, radius=self.rbf_radius)
            if i % 100 == 0:
                print("{} out of {} completed.".format(i, results.shape[0]))

        results = results.reshape(X_interp.shape)

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["voxelsize.x"] = self.target_voxelsize[0] * 1E-3
            new_mdh["voxelsize.y"] = self.target_voxelsize[1] * 1E-3
            new_mdh["voxelsize.z"] = self.target_voxelsize[2] * 1E-3
            new_mdh['Interpolation.Method'] = 'RBF'
            new_mdh['Interpolation.RbfRadius'] = self.rbf_radius
        except Exception as e:
            print(e)

        namespace[self.output_images] = ImageStack(data=results, mdh=new_mdh)

    def InterpolateAt(self, pt, X, Y, Z, data, radius=250.):
        """
        Fit a cubic RBF (smooth=1E3) to the samples within ``radius`` of ``pt`` and
        evaluate it at ``pt``.
        """
        X_subset, Y_subset, Z_subset, data_subset = self.GetPointsInNeighbourhood(pt, X, Y, Z, data, radius)
        rbf = interpolate.Rbf(X_subset, Y_subset, Z_subset, data_subset, function="cubic", smooth=1E3)
        return rbf(*pt)

    def GetPointsInNeighbourhood(self, center, X, Y, Z, data, radius=250.):
        """
        Return the coordinates and values of all samples whose Euclidean distance to
        ``center`` is strictly less than ``radius``.
        """
        distance = np.sqrt((X-center[0])**2 + (Y-center[1])**2 + (Z-center[2])**2)
        mask = distance < radius
        return X[mask], Y[mask], Z[mask], data[mask]
class AveragePSF(ModuleBase):
    """
    Input stacks of PSF and return the (normalized) average PSF.
    Additional filter based on max error/residual between image and averaged image.

    Inputs
    ------
    inputName : ImageStack
        Aligned PSFs stacked along the 4th dimension.

    Outputs
    -------
    output_images : ImageStack
        Average of the accepted PSFs, optionally Gaussian-smoothed, rescaled to [0, 1].
    output_var_image : ImageStack
        Per-voxel variance across the accepted, normalized PSFs.

    Parameters
    ----------
    normalize_intensity : bool
        If True, average the normalized PSFs; otherwise average the raw data of the
        accepted PSFs.
    gaussian_filter : list of float
        Per-axis sigma (pixels) for ``scipy.ndimage.gaussian_filter``; all zeros disables
        smoothing.
    residual_threshold : float
        A PSF is kept only if its maximum absolute deviation from the normalized mean PSF
        is strictly below this value.
    """
    inputName = Input('psf_aligned')
    normalize_intensity = Bool(False)
    output_var_image = Output('psf_var')
    gaussian_filter = List(Float, [0, 0, 0], 3, 3)
    residual_threshold = Float(0.1)
    output_images = Output('psf_combined')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_raw = ims.data[:,:,:,:]
        if not np.issubdtype(psf_raw.dtype, np.floating):
            # the in-place divisions below raise on integer image data
            psf_raw = psf_raw.astype(float)

        # always normalize first, since needed for statistics
        psf_raw_norm = psf_raw.copy()
        psf_raw_norm /= psf_raw_norm.max(axis=(0,1,2), keepdims=True)
        psf_raw_norm /= psf_raw_norm.sum(axis=(0,1), keepdims=True)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        # per-PSF maximum absolute deviation from the mean of all PSFs
        residual_max = np.abs(psf_raw_norm - psf_raw_norm.mean(axis=3, keepdims=True)).max(axis=(0,1,2))
        print(residual_max)
        mask = residual_max < self.residual_threshold
        print("images ignore: {}".format(np.argwhere(~mask)[:,0]))
        print(mask)
        if not np.any(mask):
            raise ValueError("No PSF has a residual below residual_threshold={}; residuals: {}".format(
                self.residual_threshold, residual_max))

        psf_raw_norm = psf_raw_norm[:,:,:,mask]
        print(psf_raw_norm.shape)

        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        psf_var = psf_raw_norm.var(axis=3)
        namespace[self.output_var_image] = ImageStack(psf_var, mdh=ims.mdh)

        # if requested not to normalize, revert back to original data (boolean indexing copies)
        if not self.normalize_intensity:
            psf_raw_norm = psf_raw[:,:,:,mask]

        psf_combined = psf_raw_norm.mean(axis=3)
        psf_combined -= psf_combined.min()
        psf_combined /= psf_combined.max()

        if np.any(np.asarray(self.gaussian_filter) != 0):
            psf_processed = ndimage.gaussian_filter(psf_combined, self.gaussian_filter)
        else:
            psf_processed = psf_combined
        psf_processed -= psf_processed.min()
        psf_processed /= psf_processed.max()

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["PSFExtraction.GaussianFilter"] = self.gaussian_filter
            new_mdh["PSFExtraction.NormalizeIntensity"] = self.normalize_intensity
        except Exception as e:
            print(e)
        namespace[self.output_images] = ImageStack(psf_processed, mdh=new_mdh)

        self._plot_diagnostics(psf_raw_norm, psf_combined, psf_processed, residual_max)

    def _plot_diagnostics(self, psf_raw_norm, psf_combined, psf_processed, residual_max):
        """
        Plot central X/Y/Z line profiles of the accepted PSFs (top row) and of the
        averaged (red) and processed (dashed black) PSF (bottom row), plus a histogram
        of per-PSF residuals with the rejection threshold marked.
        """
        fig, axes = pyplot.subplots(2, 3, figsize=(9,6))
        axes[0,0].set_title('X')
        axes[0,0].plot(psf_raw_norm[:, psf_raw_norm.shape[1]//2, psf_raw_norm.shape[2]//2, :])
        axes[1,0].plot(psf_combined[:, psf_combined.shape[1]//2, psf_combined.shape[2]//2], lw=1, color='red')
        axes[1,0].plot(psf_processed[:, psf_processed.shape[1]//2, psf_processed.shape[2]//2], lw=1, ls='--', color='black')
        axes[0,1].set_title('Y')
        axes[0,1].plot(psf_raw_norm[psf_raw_norm.shape[0]//2, :, psf_raw_norm.shape[2]//2, :])
        axes[1,1].plot(psf_combined[psf_combined.shape[0]//2, :, psf_combined.shape[2]//2], lw=1, color='red')
        axes[1,1].plot(psf_processed[psf_processed.shape[0]//2, :, psf_processed.shape[2]//2], lw=1, ls='--', color='black')
        axes[0,2].set_title('Z')
        axes[0,2].plot(psf_raw_norm[psf_raw_norm.shape[0]//2, psf_raw_norm.shape[1]//2, :, :])
        axes[1,2].plot(psf_combined[psf_combined.shape[0]//2, psf_combined.shape[1]//2, :], lw=1, color='red')
        axes[1,2].plot(psf_processed[psf_processed.shape[0]//2, psf_processed.shape[1]//2, :], lw=1, ls='--', color='black')
        fig.tight_layout()

        fig, ax = pyplot.subplots(1, 1, figsize=(4,3))
        ax.hist(residual_max, bins=20)
        ax.axvline(self.residual_threshold, color='red', ls='--')
class CropPSF(ModuleBase):
    """
    Crops out PSF based on positions given.
    Built-in filter by flattened index.
    Built-in filter for removing multiple peaked data.
    Filters work on the z-averaged (flattened X, Y) crops.
    Crops are stacked in the 4th dimension.

    Inputs
    ------
    inputName : ImageStack
        Image data; the 4th dimension is iterated as channels.
    input_pos :
        Per-channel sequence of position arrays; the first three entries of each row are
        used as the x, y, z indices of a PSF centre.

    Outputs
    -------
    output_images : ImageStack
        Accepted crops of shape (2*half_roi_x+1, 2*half_roi_y+1, 2*half_roi_z+1),
        stacked along the 4th dimension.

    Parameters
    ----------
    ignore_pos : list of int
        Flattened crop indices to drop. NOTE: indices of crops rejected by the automatic
        filters are appended to this list during execution.
    threshold_reject : float
        Fraction of the z-averaged crop maximum used to threshold it before labelling.
    com_reject : float
        Maximum allowed distance (pixels) between the intensity-weighted centre of mass of
        the thresholded region and the crop centre.
    half_roi_x, half_roi_y, half_roi_z : int
        Half-widths of the crop along each axis.
    """
    inputName = Input('input')
    input_pos = Input('psf_pos')
    ignore_pos = List(Int, [])
    threshold_reject = Float(0.5)
    com_reject = Float(2.0)

    half_roi_x = Int(20)
    half_roi_y = Int(20)
    half_roi_z = Int(60)

    output_images = Output('psf_cropped')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_pos = namespace[self.input_pos]

        res = np.zeros((self.half_roi_x*2+1, self.half_roi_y*2+1, self.half_roi_z*2+1,
                        sum(ar.shape[0] for ar in psf_pos)))
        mask = np.ones(res.shape[3], dtype=bool)

        counter = 0
        for c in np.arange(ims.data.shape[3]):
            for i in np.arange(len(psf_pos[c])):
                x, y, z = psf_pos[c][i][:3]
                x_slice = slice(x-self.half_roi_x, x+self.half_roi_x+1)
                y_slice = slice(y-self.half_roi_y, y+self.half_roi_y+1)
                z_slice = slice(z-self.half_roi_z, z+self.half_roi_z+1)
                res[:, :, :, counter] = ims.data[x_slice, y_slice, z_slice, c].squeeze()

                crop_flatten = res[:, :, :, counter].mean(2)
                failed = False
                labeled_image, labeled_counts = ndimage.label(crop_flatten > crop_flatten.max() * self.threshold_reject)
                if labeled_counts != 1:
                    # more than one region: multiple peaks; no region at all (e.g. empty or
                    # non-positive crop): the centre of mass below would be NaN and never reject
                    failed = True
                else:
                    com = np.asarray(ndimage.center_of_mass(crop_flatten, labeled_image, 1))
                    img_center = np.asarray([(s-1)*0.5 for s in labeled_image.shape])
                    dist = np.linalg.norm(com - img_center)
                    if dist > self.com_reject:
                        failed = True

                # rejected crops are recorded in the (user-visible) ignore_pos trait
                if failed and counter not in self.ignore_pos:
                    self.ignore_pos.append(counter)

                counter += 1

        # To do: add metadata
        print("images ignore: {}".format(self.ignore_pos))
        mask[self.ignore_pos] = False

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["ImageType"] = 'PSF'
            if "PSFExtraction.SourceFilenames" not in new_mdh.keys():
                new_mdh["PSFExtraction.SourceFilenames"] = ims.filename
        except Exception as e:
            print(e)

        namespace[self.output_images] = ImageStack(data=res[:,:,:,mask], mdh=new_mdh)
class TrackRenderLayer(EngineLayer):
    """
    A layer for viewing tracking data.

    Points are read from the pipeline datasource named by ``dsname``. Either a
    ``ClumpManager`` (iterated track by track through ``.all``) or a tabular datasource
    is accepted; for tabular data, points are grouped into tracks using the
    ``clump_key`` column.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    line_width = Float(1.0, desc='Track line width')
    method = Enum(*ENGINES.keys(), desc='Method used to display tracks')
    clump_key = CStr('clumpIndex', desc="Name of column containing the track identifier")
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='tracks', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, clump_key')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        self.method = method

        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the values of ``vertexColour`` over all points, or [0, 1] if the key is missing."""
        try:
            if isinstance(self.datasource, ClumpManager):
                cdata = []
                for track in self.datasource.all:
                    cdata.extend(track[self.vertexColour])
                cdata = np.array(cdata)
            else:
                # Assume tabular dataset
                cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        # reset the colour limits to the data range; the resulting clim change triggers update()
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]

    def update(self, *args, **kwargs):
        """Refresh datasource choices/keys and, if possible, rebuild geometry and notify listeners."""
        self._datasource_choices = self._pipeline.layer_data_source_names

        ds = self.datasource
        if ds is not None:
            if isinstance(ds, ClumpManager):
                # Grab the keys from the first Track in the ClumpManager
                self._datasource_keys = sorted(ds[0].keys())
            else:
                # Assume we have a tabular data source
                self._datasource_keys = sorted(ds.keys())

        if not (self.engine is None or ds is None):
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build vertex, normal and colour arrays from ``ds`` with points ordered by track.

        Sets ``self.clumpSizes`` and ``self.clumpStarts`` (track lengths and start offsets)
        for compatibility with the old Track rendering layer.
        """
        if isinstance(ds, ClumpManager):
            x = []
            y = []
            z = []
            c = []
            self.clumpSizes = []

            # Copy data from tracks. This is already in clump order
            # thanks to ClumpManager
            for track in ds.all:
                x.extend(track['x'])
                y.extend(track['y'])
                z.extend(track['z'])
                self.clumpSizes.append(track.nEvents)

                if self.vertexColour != '':
                    c.extend(track[self.vertexColour])
                else:
                    c.extend([0 for i in track['x']])

            x = np.array(x)
            y = np.array(y)
            z = np.array(z)
            c = np.array(c)
        else:
            # Assume tabular data source
            x, y = ds[self.x_key], ds[self.y_key]

            if self.z_key is not None:
                try:
                    z = ds[self.z_key]
                except KeyError:
                    z = 0 * x
            else:
                z = 0 * x

            if self.vertexColour != '':
                c = ds[self.vertexColour]
            else:
                c = 0 * x

            # Work out clump start and finish indices
            # TODO - optimize / precompute????
            ci = ds[self.clump_key]

            NClumps = int(ci.max())

            # clump ids are treated as 1-based: id k is collected in clist[k - 1]
            clist = [[] for i in range(NClumps)]
            for i, cl_i in enumerate(ci):
                clist[int(cl_i - 1)].append(i)

            # This and self.clumpStarts are class attributes for
            # compatibility with the old Track rendering layer,
            # PYME.LMVis.gl_render3D.TrackLayer
            self.clumpSizes = [len(cl_i) for cl_i in clist]

            # reorder x, y, z, c in clump order
            # (builtin int: the np.int alias was removed in NumPy 1.24)
            I = np.hstack([np.array(cl) for cl in clist]).astype(int)

            x = x[I]
            y = y[I]
            z = z[I]
            c = c[I]

        # computed for both datasource kinds so it always matches clumpSizes
        self.clumpStarts = np.cumsum([0, ] + list(self.clumpSizes))

        # do normal vertex stuff
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

        self._vertices = vertices
        self._normals = -0.69 * np.ones(vertices.shape)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        if clim is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = float(self.alpha)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            self._colors = np.ones((vertices.shape[0], 4), 'f')

        self._color_map = cmap
        self._color_limit = clim
        self._alpha = float(self.alpha)

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                Item('line_width'))
        ])

    def default_traits_view(self):
        return self.default_view
class LoadDriftandInterp(ModuleBase):
    """
    Loads drift data from file(s) and uses them to create a spline interpolator (``scipy.interpolate.UnivariateSpline``).

    Inputs
    ------
    input_dummy : None
        Blank input. Required to run correctly.

    Outputs
    -------
    output_drift_interpolator : list of callables
        Drift interpolator. The k-th callable returns, for a frame number / time, the sum
        over all loaded files of the k-th interpolator produced by ``interpolate_drift``.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    load_paths : list of File
        List of files to load. Each must provide ``tIndex`` and ``drift`` arrays.
    degree_of_spline : int
        Degree of the smoothing spline (1 for linear, 3 for cubic).
    smoothing_factor : float
        Smoothing factor. 0 for no smoothing; negative to use the UnivariateSpline default.
    """
    input_dummy = Input('input')  # placeholder input; the original author noted the GUI breaks without it
    load_paths = List(File, [""], 1)
    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(-1)  # 0 for no smoothing. set to negative for UnivariateSpline default
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        spl_array = list()
        tIndexes = list()
        drifts = list()
        for fil in self.load_paths:
            data = np.load(fil)
            try:
                tIndex = data['tIndex']
                drift = data['drift']
            finally:
                # np.load returns an NpzFile holding an open file handle for .npz archives;
                # release it once the arrays have been read (plain ndarrays have no close()).
                if hasattr(data, 'close'):
                    data.close()
            tIndexes.append(tIndex)
            drifts.append(drift)

            spl = interpolate_drift(tIndex, drift, self.degree_of_spline, self.smoothing_factor)
            spl_array.append(spl)

        # regroup so that the k-th group holds the k-th interpolator from every file
        spl_array = zip(*spl_array)

        def spl_method(funcs, t):
            return np.sum([f(t) for f in funcs], axis=0)

        # partial binds each group now; a lambda closing over the loop variable would late-bind
        spl_combined = [partial(spl_method, spl) for spl in spl_array]

        namespace[self.output_drift_interpolator] = spl_combined

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, tIndexes, drifts, spl_combined))
        namespace[self.output_drift_plot].plot()