class PSFSettings(HasTraits):
    """
    Parameters describing a point-spread function: emission wavelength, numerical aperture, vectorial vs. scalar
    model, Zernike mode dictionaries and (optionally) 4Pi settings.

    The 4Pi-only fields (`zernike_modes_lower`, `phases`) are only shown in the GUI when `four_pi` is enabled.
    """
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])  # displayed as 'phases/pi' in the GUI
    four_pi = Bool(False)

    def default_traits_view(self):
        """Build the traitsui View used to edit these settings."""
        from traitsui.api import View, Item

        only_for_4pi = 'four_pi==True'
        items = [
            Item(name='wavelength_nm'),
            Item(name='NA'),
            Item(name='vectorial'),
            Item(name='four_pi', label='4Pi'),
            Item(name='zernike_modes'),
            Item(name='zernike_modes_lower', visible_when=only_for_4pi),
            Item(name='phases', visible_when=only_for_4pi, label='phases/pi'),
        ]
        return View(*items, resizable=True, buttons=['OK'])
class PointFeatureBase(ModuleBase):
    """
    common base class for feature extraction routines - implements normalisation and PCA routines
    """
    outputColumnName = CStr('features')
    columnForEachFeature = Bool(False)  # if true, outputs a column for each feature - useful for visualising

    normalise = Bool(True)  # subtract mean and divide by std. deviation

    PCA = Bool(True)  # reduce feature dimensionality by performing PCA - TODO - should this be a separate module and be chained instead?
    PCA_components = Int(3)  # 0 = same dimensionality as features

    def _process_features(self, data, features):
        """
        Optionally normalise and PCA-reduce a feature matrix, and attach it to the input data.

        Parameters
        ----------
        data : tabular data source
            The point data the features were computed from. It is wrapped (not copied) in a
            `PYME.IO.tabular.MappingFilter`, and its `mdh` (if any) is propagated.
        features : numpy.ndarray
            2D array with one row per point and one column per feature.

        Returns
        -------
        PYME.IO.tabular.MappingFilter
            `data` with an extra column `outputColumnName` holding the (processed) features, plus one `feat_<i>`
            column per feature if `columnForEachFeature` is set. When PCA is performed, the fitted sklearn PCA
            object is stored as `out.pca`.
        """
        from PYME.IO import tabular
        out = tabular.MappingFilter(data)
        out.mdh = getattr(data, 'mdh', None)

        if self.normalise:
            features = features - features.mean(0)[None, :]
            std = features.std(0)
            # A constant feature column has zero std. Dividing by 0 would fill that column with NaN, which then
            # makes the PCA below fail. Such columns are already all-zero after mean subtraction, so leave them.
            std[std == 0] = 1
            features = features / std[None, :]

        if self.PCA:
            from sklearn.decomposition import PCA

            n_components = self.PCA_components if self.PCA_components > 0 else None
            pca = PCA(n_components=n_components).fit(features)
            features = pca.transform(features)

            out.pca = pca  # save the pca object just in case we want to look at what the principle components are (this is hacky)

        out.addColumn(self.outputColumnName, features)

        if self.columnForEachFeature:
            for i in range(features.shape[1]):
                out.addColumn('feat_%d' % i, features[:, i])

        return out
class ArithmaticFilter(ModuleBase):
    """
    Module with two image inputs and one image output

    Subclasses implement `applyFilter(data0, data1, chanNum, frameNum, image0)`, which is called once per channel
    (or once per channel and z-slice when `processFramesIndividually` is set) with float32 data from both images.

    Parameters
    ----------
    inputName0: PYME.IO.image.ImageStack
    inputName1: PYME.IO.image.ImageStack
    outputName: PYME.IO.image.ImageStack

    """
    inputName0 = Input('input')
    inputName1 = Input('input')
    outputName = Output('filtered_image')

    processFramesIndividually = Bool(False)

    def filter(self, image0, image1):
        """Apply `applyFilter` to matching chunks of `image0` and `image1` and assemble the results into an ImageStack."""
        per_frame = self.processFramesIndividually
        filt_ims = []

        for chan in range(image0.data.shape[3]):
            if per_frame:
                slices = []
                for frame in range(image0.data.shape[2]):
                    a = image0.data[:, :, frame, chan].squeeze().astype('f')
                    b = image1.data[:, :, frame, chan].squeeze().astype('f')
                    slices.append(np.atleast_3d(self.applyFilter(a, b, chan, frame, image0)))
                filt_ims.append(np.concatenate(slices, 2))
            else:
                a = image0.data[:, :, :, chan].squeeze().astype('f')
                b = image1.data[:, :, :, chan].squeeze().astype('f')
                filt_ims.append(np.atleast_3d(self.applyFilter(a, b, chan, 0, image0)))

        result = ImageStack(filt_ims, titleStub=self.outputName)
        result.mdh.copyEntriesFrom(image0.mdh)
        result.mdh['Parents'] = '%s, %s' % (image0.filename, image1.filename)

        self.completeMetadata(result)

        return result

    def execute(self, namespace):
        image0 = namespace[self.inputName0]
        image1 = namespace[self.inputName1]
        namespace[self.outputName] = self.filter(image0, image1)

    def completeMetadata(self, im):
        """Hook for subclasses to add entries to the output metadata; does nothing by default."""
        pass
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its distances to all other points
    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  # number of bins (starting at 0)
    threeD = Bool(True)

    normaliseRelativeDensity = Bool(False)  # divide by the sum of all radial bins. If not performed, the first principle component will likely be average density

    def execute(self, namespace):
        """Compute one distance histogram per point and pass the stacked histograms to `_process_features`."""
        from PYME.Analysis.points import DistHist

        points = namespace[self.inputLocalisations]

        # range (not xrange) so this runs on both Python 2 and 3
        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([DistHist.distanceHistogram3D(x[i], y[i], z[i], x, y, z, self.numBins, self.binWidth)
                          for i in range(len(x))])
        else:
            x, y = points['x'], points['y']
            f = np.array([DistHist.distanceHistogram(x[i], y[i], x, y, self.numBins, self.binWidth)
                          for i in range(len(x))])

        if self.normaliseRelativeDensity:
            # Scale each point's histogram so its bins sum to one (the trait was previously declared but ignored).
            # Assumes f is (n_points, numBins). All-zero rows are left unchanged to avoid 0/0.
            totals = f.sum(1).astype('f8')
            totals[totals == 0] = 1
            f = f / totals[:, None]

        namespace[self.outputName] = self._process_features(points, f)
class EngineLayer(BaseLayer):
    """
    Base class for layers who delegate their rendering to an engine.
    """
    engine = Instance(BaseEngine)
    show_lut = Bool(True)

    def render(self, gl_canvas):
        """
        Delegate drawing to `self.engine` if the layer is visible.

        Returns whatever `engine.render(gl_canvas, self)` returns, or None when the layer is hidden.
        """
        if self.visible:
            return self.engine.render(gl_canvas, self)

    @abc.abstractmethod
    def get_vertices(self):
        """
        Provides the engine with a way of obtaining vertex data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of vertices suitable for passing to glVertexPointerf()

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_normals(self):
        """
        Provides the engine with a way of obtaining per-vertex normal data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of normals suitable for passing to glNormalPointerf()

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_colors(self):
        """
        Provides the engine with a way of obtaining per-vertex colour data. Should be over-ridden in derived class

        Returns
        -------
        a numpy array of colors suitable for passing to glColorPointerf()

        """
        raise NotImplementedError()
class ExtractChannelByName(ModuleBase):
    """Extract one channel from an image using regular expression matching to image channel names - by default this is case insensitive"""
    inputName = Input('input')
    outputName = Output('filtered_image')

    channelNamePattern = CStr('channel0')
    caseInsensitive = Bool(True)

    def _matchChannels(self, channelNames):
        """Return the indices of all entries of `channelNames` which `channelNamePattern` matches (via re.search).

        Kept as a separate method so that it can be called externally for testing.
        """
        import re

        flags = re.I if self.caseInsensitive else 0
        return [idx for idx, name in enumerate(channelNames)
                if re.search(self.channelNamePattern, name, flags)]

    def _pickChannel(self, image):
        """Return a single-channel ImageStack for the unique channel matching `channelNamePattern`.

        Raises RuntimeError if no channel, or more than one channel, matches.
        """
        channelNames = image.mdh['ChannelNames']
        matches = self._matchChannels(channelNames)

        if not matches:
            raise RuntimeError("Expression '%s' did not match any channel names" % self.channelNamePattern)

        if len(matches) > 1:
            matched_names = ', '.join(channelNames[m] for m in matches)
            raise RuntimeError(("Expression '%s' did match more than one channel name: " % self.channelNamePattern)
                               + matched_names)

        selected = matches[0]

        picked = ImageStack(image.data[:, :, :, selected], titleStub='Filtered Image')
        picked.mdh.copyEntriesFrom(image.mdh)
        picked.mdh['ChannelNames'] = [channelNames[selected]]
        picked.mdh['Parent'] = image.filename

        return picked

    def execute(self, namespace):
        source = namespace[self.inputName]
        namespace[self.outputName] = self._pickChannel(source)
class BaseLayer(HasTraits):
    """
    This class represents a layer that should be rendered. It should represent a fairly high level concept of a layer -
    e.g. a Point-cloud of data coming from XX, or a Surface representation of YY. If such a layer can be rendered
    multiple different but similar ways (e.g. points/pointsprites or shaded/wireframe etc) which otherwise share common
    settings e.g. point size, point colour, etc ... these representations should be coded as one layer with a selectable
    rendering backend or 'engine' responsible for managing shaders and actually executing the opengl code. In this case
    use the `EngineLayer` class as a base.

    In simpler cases, such as rendering an overlay it is acceptable for a layer to do its own rendering and manage its
    own shader. In this case, use `SimpleLayer` as a base.
    """
    visible = Bool(True)

    def __init__(self, context=None, **kwargs):
        # NOTE(review): HasTraits.__init__ is deliberately not called here, so any **kwargs (including trait values
        # such as visible=...) are currently ignored. Confirm before relying on keyword initialisation.
        self._context = context

    @property
    def bbox(self):
        """Bounding box in form [x0,y0,z0, x1,y1,z1], or None when a bounding box does not make sense for this layer.

        The base implementation always returns None; over-ride in derived classes.
        """
        return None

    @abc.abstractmethod
    def render(self, gl_canvas):
        """
        Abstract render method to be over-ridden in derived classes. Should check self.visible before drawing anything.

        Parameters
        ----------
        gl_canvas : the canvas to draw to - an instance of PYME.LMVis.gl_render3D_shaders.LMGLShaderCanvas
        """
class LabelByRegionProperty(Filter):
    """Assigns a region property to each contiguous region in the input mask.

    Optionally throws away all regions for which property is outside a given range. Regions for which the property is
    undefined (zero minor axis length for 'aspectratio', zero perimeter for 'circularity' - e.g. single pixels) are
    left unlabelled (0).
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def _region_property(self, region):
        """Return the selected property for a skimage region, or None if it is undefined for this region."""
        if self.regionProperty == 'area':
            return region.area
        elif self.regionProperty == 'aspectratio':
            minor = region.minor_axis_length
            if minor == 0:
                # degenerate (line-like / single pixel) region - would otherwise raise ZeroDivisionError
                return None
            return region.major_axis_length / minor
        elif self.regionProperty == 'circularity':
            perimeter = region.perimeter
            if perimeter == 0:
                # would otherwise write inf into the output image
                return None
            return 4 * math.pi * region.area / (perimeter * perimeter)

        raise ValueError('Unknown region property: %s' % self.regionProperty)

    def applyFilter(self, data, chanNum, frNum, im):
        """Label each connected region of `data > 0.5` with the value of its selected property.

        Returns a float array shaped like `data`, 0 outside labelled regions.
        """
        mask = data > 0.5
        labs, nlabs = ndimage.label(mask)

        rp = skimage.measure.regionprops(labs, None, cache=True)

        m2 = np.zeros_like(mask, dtype='float')
        objs = ndimage.find_objects(labs)
        for region in rp:
            oslices = objs[region.label - 1]
            r = labs[oslices] == region.label

            propValue = self._region_property(region)
            if propValue is None:
                continue

            if (not self.filterByProperty) or (self.propertyMin <= propValue <= self.propertyMax):
                m2[oslices] += r * propValue

        return m2

    def completeMetadata(self, im):
        im.mdh['Labelling.Property'] = self.regionProperty
class DetectPSF(ModuleBase):
    """
    Detect PSF based on diff of gaussian

    Image dims in X, Y, Z, C where C are processed independently.
    Writes to `output_pos` a list with one entry per channel C. Each entry is an integer array with one row
    (x, y, z, sigma) per detected PSF, where z is the middle slice of the stack.
    """
    inputName = Input('input')

    min_sigma = Float(1.0)
    max_sigma = Float(3.0)
    sigma_ratio = Float(1.6)
    percent_threshold = Float(0.1)
    overlap = Float(0.5)
    exclude_border = Int(50)
    ignore_z = Bool(True)

    output_pos = Output('psf_pos')

    @staticmethod
    def _mean_projection(ims, c):
        """Mean projection along z of channel `c`, with pixels equal to 2**16-1 replaced by 200
        (presumably to suppress saturated pixels)."""
        mean_project = ims.data[:, :, :, c].mean(2).squeeze()
        mean_project[mean_project == 2**16 - 1] = 200
        return mean_project

    def execute(self, namespace):
        ims = namespace[self.inputName]
        # sigmas are divided by voxelsize.x to convert them to pixels for blob_dog
        pixel_size = ims.mdh['voxelsize.x']

        pos = list()
        counts = ims.data.shape[3]
        for c in range(counts):
            mean_project = self._mean_projection(ims, c)
            mean_project -= mean_project.min()
            peak = mean_project.max()
            if peak > 0:
                # a constant image would otherwise be normalised to all-NaN
                mean_project /= peak

            # newer skimage versions can take exclude_border directly; the border is excluded manually below
            # so that older versions keep working.
            blobs = feature.blob_dog(mean_project, self.min_sigma / pixel_size, self.max_sigma / pixel_size,
                                     sigma_ratio=self.sigma_ratio, overlap=self.overlap,
                                     threshold=self.percent_threshold*mean_project.max())

            edge_mask = (blobs[:, 0] > self.exclude_border) & (blobs[:, 0] < mean_project.shape[0] - self.exclude_border)
            edge_mask &= (blobs[:, 1] > self.exclude_border) & (blobs[:, 1] < mean_project.shape[1] - self.exclude_border)
            blobs = blobs[edge_mask]
            # blobs rows are (x, y, sigma) here

            if self.ignore_z:
                blobs = np.insert(blobs, 2, ims.data.shape[2]//2, axis=1)
            else:
                raise Exception("z centering not yet implemented")

            # np.int was removed in numpy 1.24; plain int is the same type
            blobs = blobs.astype(int)
            pos.append(blobs)

        namespace[self.output_pos] = pos

        # Best-effort diagnostic plot of the detections; failures are reported but do not abort the module.
        # NOTE(review): relies on `pyplot` being available in module scope (its local import is commented out).
        try:
            # from matplotlib import pyplot
            fig, axes = pyplot.subplots(1, counts, figsize=(4*counts, 3), squeeze=False)
            for c in range(counts):
                axes[0, c].imshow(self._mean_projection(ims, c))
                axes[0, c].set_axis_off()

                for x, y, z, sig in pos[c]:
                    cir = pyplot.Circle((y, x), sig, color='red', linewidth=2, fill=False)
                    axes[0, c].add_patch(cir)
        except Exception as e:
            print(e)
class ModuleCollection(HasTraits):
    """
    A recipe - a collection of processing modules together with a namespace (a dict mapping data names to data
    objects) which the modules read their inputs from and write their outputs to.

    Modules are executed in an order derived from their declared inputs and outputs (see `resolveDependencies`).
    """
    modules = List()
    execute_on_invalidation = Bool(False)

    def __init__(self, *args, **kwargs):
        HasTraits.__init__(self, *args, **kwargs)

        self.namespace = {}

        # we open hdf files and don't necessarily read their contents into memory - these need to be closed when we
        # either delete the recipe, or clear the namespace
        self._open_input_files = []

        self.recipe_changed = dispatch.Signal()
        self.recipe_executed = dispatch.Signal()

    def invalidate_data(self):
        """Re-run the recipe if `execute_on_invalidation` is set."""
        if self.execute_on_invalidation:
            self.execute()

    def clear(self):
        """Remove all data from the namespace."""
        self.namespace.clear()

    def new_output_name(self, stub):
        """
        Suggest a namespace key based on `stub`.

        Returns `stub` if no existing key starts with `stub`, otherwise `'<stub>_<N>'` where N is the number of keys
        starting with `stub`, incremented further until the name is not already in the namespace.
        """
        # count only keys which actually start with the stub (previously every key was counted)
        count = sum(1 for k in self.namespace.keys() if k.startswith(stub))
        if count == 0:
            return stub

        name = '%s_%d' % (stub, count)
        while name in self.namespace:
            count += 1
            name = '%s_%d' % (stub, count)
        return name

    def dependancyGraph(self):
        """
        Build the dependency graph of the recipe.

        Returns
        -------
        dict
            Maps each module to the set of its inputs, and each module output to a set containing the module which
            produces it. This is the format expected by `toposort`.
        """
        dg = {}

        # only add items to dependancy graph if they are not already in the namespace
        for mod in self.modules:
            s = mod.inputs

            try:
                s.update(dg[mod])
            except KeyError:
                pass

            dg[mod] = s

            for op in mod.outputs:
                dg[op] = {mod, }

        return dg

    def reverseDependancyGraph(self):
        """Invert `dependancyGraph`: map each node to the set of nodes which directly depend on it."""
        dg = self.dependancyGraph()

        rdg = {}

        for k, vs in dg.items():
            for v in vs:
                rdg.setdefault(v, set()).add(k)

        return rdg

    def _getAllDownstream(self, rdg, keys):
        """get all the downstream items which depend on the given key"""

        downstream = set()

        next_level = set()

        for k in keys:
            try:
                next_level.update(rdg[k])
            except KeyError:
                pass

        if next_level:
            downstream.update(next_level)

            downstream.update(self._getAllDownstream(rdg, list(next_level)))

        return downstream

    def prune_dependencies_from_namespace(self, keys_to_prune, keep_passed_keys=False):
        """Remove everything downstream of `keys_to_prune` (and, unless `keep_passed_keys`, the keys themselves)
        from the namespace, so that it is recomputed on the next `execute`."""
        rdg = self.reverseDependancyGraph()

        if keep_passed_keys:
            downstream = list(self._getAllDownstream(rdg, list(keys_to_prune)))
        else:
            downstream = list(keys_to_prune) + list(self._getAllDownstream(rdg, list(keys_to_prune)))

        for dsi in downstream:
            try:
                self.namespace.pop(dsi)
            except KeyError:
                # the output is not in our namespace, no need to prune
                pass
            except AttributeError:
                # we might not have our namespace defined yet
                pass

    def resolveDependencies(self):
        """Return the modules and data keys of the recipe in a topologically sorted (execution) order."""
        import toposort
        # build dependancy graph
        dg = self.dependancyGraph()

        # solve the dependency tree
        return toposort.toposort_flatten(dg, sort=False)

    def execute(self, **kwargs):
        """
        Run the recipe.

        Keyword arguments are injected into the namespace as inputs; if an input differs from the value already in
        the namespace, everything downstream of it is pruned first. Modules whose outputs are already in the namespace
        are not re-run.

        Returns
        -------
        The namespace entry 'output' if present, otherwise None.
        """
        # remove anything which is downstream from changed inputs
        for k, v in kwargs.items():
            try:
                if not (self.namespace[k] == v):
                    # input has changed
                    print('pruning: ', k)
                    self.prune_dependencies_from_namespace([k])
            except KeyError:
                # key wasn't in namespace previously
                print('KeyError')
                pass

        self.namespace.update(kwargs)

        exec_order = self.resolveDependencies()

        for m in exec_order:
            if isinstance(m, ModuleBase) and not m.outputs_in_namespace(self.namespace):
                try:
                    m.execute(self.namespace)
                except Exception:
                    logger.exception("Error in recipe module: %s" % m)
                    raise

        self.recipe_executed.send_robust(self)

        if 'output' in self.namespace:
            return self.namespace['output']

    @classmethod
    def fromMD(cls, md):
        c = cls()

        moduleNames = set([s.split('.')[0] for s in md.keys()])

        mc = []

        for mn in moduleNames:
            mod = all_modules[mn]()
            mod.set(**md[mn])
            mc.append(mod)

        c.modules = mc

        return c

    def get_cleaned_module_list(self):
        """
        Return the modules as a list of single-entry dicts {module_name: {trait: value}}, suitable for serialisation.

        Private (underscore) traits and traits at their default value are omitted (except inputs/outputs, which are
        always kept), and traits containers are converted to plain dict/list/set.
        """
        l = []
        for mod in self.modules:
            ct = mod.class_traits()

            mod_traits_cleaned = {}
            for k, v in mod.get().items():
                if not k.startswith('_'):  # don't save private data - this is usually used for caching etc ..,
                    try:
                        if (not (v == ct[k].default)) or (k.startswith('input')) or (k.startswith('output')):
                            # don't save defaults
                            if isinstance(v, dict) and not type(v) == dict:
                                v = dict(v)
                            elif isinstance(v, list) and not type(v) == list:
                                v = list(v)
                            elif isinstance(v, set) and not type(v) == set:
                                v = set(v)

                            mod_traits_cleaned[k] = v
                    except KeyError:
                        # for some reason we have a trait that shouldn't be here
                        pass

            l.append({module_names[mod.__class__]: mod_traits_cleaned})

        return l

    def toYAML(self):
        import yaml

        class MyDumper(yaml.SafeDumper):
            def represent_mapping(self, tag, value, flow_style=None):
                return super(MyDumper, self).represent_mapping(tag, value, False)

        return yaml.dump(self.get_cleaned_module_list(), Dumper=MyDumper)

    def toJSON(self):
        import json
        return json.dumps(self.get_cleaned_module_list())

    def _update_from_module_list(self, l):
        """
        Update from a parsed yaml or json list of modules

        It probably makes no sense to call this directly as the format is pretty wack - a
        list of dictionarys each with a single entry, but that is how the yaml parses

        Parameters
        ----------
        l: list
            List of modules as obtained from parsing a yaml recipe,
            Each module is a dictionary mapping with a single e.g.
            [{'Filtering.Filter': {'filters': {'probe': [-0.5, 0.5]}, 'input': 'localizations', 'output': 'filtered'}}]

        Returns
        -------

        """
        mc = []

        if l is None:
            l = []

        for mdd in l:
            mn, md = list(mdd.items())[0]
            try:
                mod = all_modules[mn](self)
            except KeyError:
                # still support loading old recipes which do not use hierarchical names
                # also try and support modules which might have moved
                mod = _legacy_modules[mn.split('.')[-1]](self)

            mod.set(**md)
            mc.append(mod)

        self.modules = mc

        self.recipe_changed.send_robust(self)
        self.invalidate_data()

    @classmethod
    def _from_module_list(cls, l):
        """ A factory method which contains the common logic for loading/creating from either
        yaml or json. Do not call directly"""
        c = cls()
        c._update_from_module_list(l)

        return c

    @classmethod
    def fromYAML(cls, data):
        import yaml

        # safe_load: recipes are plain data (they are written with a SafeDumper), and plain yaml.load without a
        # Loader can construct arbitrary objects (and raises TypeError on PyYAML >= 6).
        l = yaml.safe_load(data)
        return cls._from_module_list(l)

    def update_from_yaml(self, data):
        """
        Update from a yaml formatted recipe description

        Parameters
        ----------
        data: str
            either yaml formatted text, or the path to a yaml file.

        Returns
        -------
        None

        """
        import os
        import yaml

        if os.path.isfile(data):
            with open(data) as f:
                data = f.read()

        l = yaml.safe_load(data)
        return self._update_from_module_list(l)

    @classmethod
    def fromJSON(cls, data):
        import json
        return cls._from_module_list(json.loads(data))

    def add_module(self, module):
        self.modules.append(module)
        self.recipe_changed.send_robust(self)

    @property
    def inputs(self):
        """Module inputs whose names start with 'in'."""
        ip = set()
        for mod in self.modules:
            ip.update({k for k in mod.inputs if k.startswith('in')})
        return ip

    @property
    def outputs(self):
        """Module outputs whose names start with 'out'."""
        op = set()
        for mod in self.modules:
            op.update({k for k in mod.outputs if k.startswith('out')})
        return op

    @property
    def module_outputs(self):
        """All module outputs."""
        op = set()
        for mod in self.modules:
            op.update(set(mod.outputs))
        return op

    @property
    def file_inputs(self):
        out = []
        for mod in self.modules:
            out += mod.file_inputs

        return out

    def save(self, context=None):
        """
        Find all OutputModule instances and call their save methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to an empty dict.

        """
        if context is None:
            context = {}

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                mod.save(self.namespace, context)

    def gather_outputs(self, context=None):
        """
        Find all OutputModule instances and call their generate methods with the recipe context

        Parameters
        ----------
        context : dict, optional
            A context dictionary used to substitute and create variable names. Defaults to an empty dict.

        Returns
        -------
        list
            The non-None results of each OutputModule's `generate` call.
        """
        if context is None:
            context = {}

        outputs = []

        for mod in self.modules:
            if isinstance(mod, OutputModule):
                out = mod.generate(self.namespace, context)

                if out is not None:
                    outputs.append(out)

        return outputs

    def loadInput(self, filename, key='input'):
        """Load input data from a file and inject into namespace

        Currently only handles images (anything you can open in dh5view). TODO - extend to other types.
        """
        # modify this to allow for different file types - currently only supports images
        from PYME.IO import unifiedIO
        import os
        extension = os.path.splitext(filename)[1]
        if extension in ['.h5r', '.h5', '.hdf']:
            import tables
            from PYME.IO import MetaDataHandler
            from PYME.IO import tabular

            with unifiedIO.local_or_temp_filename(filename) as fn:
                with tables.open_file(fn, mode='r') as h5f:
                    # make sure our hdf file gets closed

                    key_prefix = '' if key == 'input' else key + '_'

                    try:
                        mdh = MetaDataHandler.NestedClassMDHandler(MetaDataHandler.HDFMDHandler(h5f))
                    except tables.FileModeError:  # Occurs if no metadata is found, since we opened the table in read-mode
                        logger.warning('No metadata found, proceeding with empty metadata')
                        mdh = MetaDataHandler.NestedClassMDHandler()

                    for t in h5f.list_nodes('/'):
                        # FIXME - The following isinstance tests are not very safe (and badly broken in some cases e.g.
                        # PZF formatted image data, Image data which is not in an EArray, etc ...)
                        # Note that EArray is only used for streaming data!
                        # They should ideally be replaced with more comprehensive tests (potentially based on array or dataset
                        # dimensionality and/or data type) - i.e. duck typing. Our strategy for images in HDF should probably
                        # also be improved / clarified - can we use hdf attributes to hint at the data intent? How do we support
                        # > 3D data?

                        if isinstance(t, tables.VLArray):
                            from PYME.IO.ragged import RaggedVLArray

                            rag = RaggedVLArray(h5f, t.name, copy=True)  # force an in-memory copy so we can close the hdf file properly
                            rag.mdh = mdh

                            self.namespace[key_prefix + t.name] = rag

                        elif isinstance(t, tables.table.Table):
                            # pipe our table into h5r or hdf source depending on the extension
                            tab = tabular.H5RSource(h5f, t.name) if extension == '.h5r' else tabular.HDFSource(h5f, t.name)
                            tab.mdh = mdh

                            self.namespace[key_prefix + t.name] = tab

                        elif isinstance(t, tables.EArray):
                            # load using ImageStack._loadh5, which finds metdata
                            im = ImageStack(filename=filename, haveGUI=False)
                            # assume image is the main table in the file and give it the named key
                            self.namespace[key] = im

        elif extension == '.csv':
            logger.error('loading .csv not supported yet')
            raise NotImplementedError
        elif extension in ['.xls', '.xlsx']:
            logger.error('loading .xls not supported yet')
            raise NotImplementedError
        else:
            self.namespace[key] = ImageStack(filename=filename, haveGUI=False)

    @property
    def pipeline_view(self):
        import wx
        if wx.GetApp() is None:
            return None
        else:
            from traitsui.api import View, ListEditor, InstanceEditor, Item

            return View(Item('modules', editor=ListEditor(style='custom', editor=InstanceEditor(view='pipeline_view'),
                                                          mutable=False),
                             style='custom', show_label=False),
                        buttons=['OK', 'Cancel'])

    def to_svg(self):
        from . import recipeLayout
        return recipeLayout.to_svg(self.dependancyGraph())

    def _repr_svg_(self):
        """ Make us look pretty in Jupyter"""
        return self.to_svg()
class Filter(ModuleBase):
    """Module with one image input and one image output

    Subclasses implement `applyFilter(data, chanNum, frameNum, image)`, which is called once per channel (or once per
    channel and z-slice when `processFramesIndividually` is set) with float32 data.
    """
    inputName = Input('input')
    outputName = Output('filtered_image')

    processFramesIndividually = Bool(True)

    def filter(self, image):
        """Apply `applyFilter` to `image` and assemble the results into a new ImageStack."""
        if self.processFramesIndividually:
            filt_ims = []
            for chanNum in range(image.data.shape[3]):
                filt_ims.append(np.concatenate([np.atleast_3d(self.applyFilter(image.data[:, :, i, chanNum].squeeze().astype('f'), chanNum, i, image))
                                                for i in range(image.data.shape[2])], 2))
        else:
            filt_ims = [np.atleast_3d(self.applyFilter(image.data[:, :, :, chanNum].squeeze().astype('f'), chanNum, 0, image))
                        for chanNum in range(image.data.shape[3])]

        im = ImageStack(filt_ims, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(image.mdh)
        im.mdh['Parent'] = image.filename

        self.completeMetadata(im)

        return im

    def execute(self, namespace):
        namespace[self.outputName] = self.filter(namespace[self.inputName])

    def completeMetadata(self, im):
        """Hook for subclasses to add entries to the output metadata; does nothing by default."""
        pass

    @classmethod
    def dsviewer_plugin_callback(cls, dsviewer, showGUI=True, **kwargs):
        """Implements a callback which allows this module to be used as a plugin for dsviewer.

        Parameters
        ----------

        dsviewer : :class:`PYME.DSView.dsviewer.DSViewFrame` instance
            This is the current :class:`~PYME.DSView.dsviewer.DSViewFrame` instance. The filter will be run with the
            associated ``.image`` as input and display the output in a new window.

        showGUI : bool
            Should we show a GUI to set parameters (generated by calling configure_traits()), or just run with default
            parameters.

        **kwargs : dict
            Optionally, provide default values for parameters. Makes most sense when used with showGUI = False

        """
        from PYME.DSView import ViewIm3D

        mod = cls(inputName='input', outputName='output', **kwargs)
        if (not showGUI) or mod.configure_traits(kind='modal'):
            namespace = {'input': dsviewer.image}
            mod.execute(namespace)

            # the result is written into the namespace (not onto the module) by execute()
            ViewIm3D(namespace['output'], parent=dsviewer, glCanvas=dsviewer.glCanvas)
class CalculateFRCBase(ModuleBase):
    """
    Base class. Refer to derived classes for docstrings.

    Derived classes are expected to provide `_namespace`, `_pixel_size_in_nm` (one entry per image dimension) and,
    when `multiprocessing` is set, a `_pool` with an `apply_async` method.
    """
    pre_filter = Enum(['Tukey_1/8', None])
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    multiprocessing = Bool(True)
    cubic_smoothing = Float(0.01)
    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, namespace=None):
        # accepts the namespace argument the recipe engine passes, so callers get this error rather than a TypeError
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Check both images have identical shape and apply the optional `pre_filter`.

        Raises AssertionError if the shapes differ and ValueError for an unknown `pre_filter`.
        """
        dims_length = np.stack([im.shape for im in image_pair], 0)
        # explicit check rather than `assert`, which is stripped when running with -O
        if not np.all(dims_length == dims_length[0]):
            raise AssertionError("Images not the same dimension.")

        # apply filtering to zero near edge of images
        if self.pre_filter == 'Tukey_1/8':
            image_pair = self.filter_images_tukey(image_pair, 1./8)
        elif self.pre_filter is not None:
            raise ValueError("Unknown pre_filter: %r" % (self.pre_filter,))

        return image_pair

    def filter_images_tukey(self, images, alpha):
        """Multiply both images by a separable n-D Tukey window with shape parameter `alpha`."""
        try:
            from scipy.signal.windows import tukey
        except ImportError:
            # very old scipy; scipy.signal.tukey itself was removed in scipy 1.13
            from scipy.signal import tukey

        windows = [tukey(images[0].shape[i], alpha=alpha) for i in range(images[0].ndim)]
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)

        return [images[0]*window_nd, images[1]*window_nd]

    def calculate_FRC_from_images(self, image_pair, mdh):
        """
        Compute the Fourier ring/shell correlation of two equally sized images.

        Returns
        -------
        (res, rawdata) as produced by `calculate_threshold`.
        """
        if self.multiprocessing:
            results = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [res.get() for res in results]
            del results
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        # radial spatial frequency of every Fourier-space voxel
        im_fft_freqs = [np.fft.fftfreq(image_pair[0].shape[i], self._pixel_size_in_nm[i]) for i in range(image_pair[0].ndim)]
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        im1_fft_power = np.multiply(ft_images[0], np.conj(ft_images[0]))
        im2_fft_power = np.multiply(ft_images[1], np.conj(ft_images[1]))
        im12_fft_power = np.multiply(ft_images[0], np.conj(ft_images[1]))

        try:
            self._namespace[self.output_fft_images_cc] = ImageStack(data=np.stack([np.atleast_3d(np.fft.fftshift(im1_fft_power)),
                                                                                   np.atleast_3d(np.fft.fftshift(im2_fft_power)),
                                                                                   np.atleast_3d(np.fft.fftshift(im12_fft_power))], 3),
                                                                    titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            print(e)

        im1_fft_flat_res = CalculateFRCBase.BinData(im_R.flatten(), im1_fft_power.flatten(), statistic='mean', bins=201)
        im2_fft_flat_res = CalculateFRCBase.BinData(im_R.flatten(), im2_fft_power.flatten(), statistic='mean', bins=201)
        im12_fft_flat_res = CalculateFRCBase.BinData(im_R.flatten(), im12_fft_power.flatten(), statistic='mean', bins=201)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic*im2_fft_flat_res.statistic))

        smoothed_frc = self.smooth_frc(im12_fft_flat_res.bin_edges[:-1], corr, self.cubic_smoothing)

        res, rawdata = self.calculate_threshold(im12_fft_flat_res.bin_edges[:-1], corr, smoothed_frc, im12_fft_flat_res.counts)

        return res, rawdata

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """Return a callable `f(x=...)` approximating `corr` as a function of `freq`, per `frc_smoothing_func`."""
        if self.frc_smoothing_func is None:
            interp_frc = interpolate.interp1d(freq, corr, kind='next', )
            return interp_frc

        elif self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            # integer division: a float index fails on Python 3
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x)-corr)), [1, freq[len(freq)//2]],
                                        args=(freq,), method='Nelder-Mead')
            return partial(func, n=fit_res.x[0], c=fit_res.x[1])

        elif self.frc_smoothing_func == "Cubic Spline":
            interp_frc = interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)
            return interp_frc

        raise ValueError("Unknown frc_smoothing_func: %r" % (self.frc_smoothing_func,))

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Find the resolutions at which the smoothed FRC crosses the 1/7, 3 sigma and 1/2 bit thresholds, and
        store a plot of the curves in the namespace.

        Returns
        -------
        res : dict
            'frc 1/7', 'frc 3 sigma', 'frc half bit' -> resolution (1/frequency)
        rawdata : dict
            the curves used, sampled at `freq`
        """
        res = dict()

        fsc_0143 = optimize.minimize(lambda x: np.square(corr_func(x=x)-0.143), freq[np.argmax(corr_func(x=freq)-0.143 < 0)], method='Nelder-Mead')
        res['frc 1/7'] = 1./fsc_0143.x[0]

        sigma = 1.0 / np.sqrt(counts*0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = optimize.minimize(lambda x: np.square(corr_func(x=x)-3.*sigma_spl(x)), freq[np.argmax(corr_func(x=freq)-3.*sigma_spl(freq) < 0)], method='Nelder-Mead')
        res['frc 3 sigma'] = 1./fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        half_bit = (0.2071 + 1.9102 / np.sqrt(counts)) / (1.2071 + 0.9102 / np.sqrt(counts))
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = optimize.minimize(lambda x: np.square(corr_func(x=x)-half_bit_spl(x)), freq[np.argmax(corr_func(x=freq)-half_bit_spl(freq) < 0)], method='Nelder-Mead')
        res['frc half bit'] = 1./fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1, 2, figsize=(10, 4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7: {:.2f} nm".format(1./fsc_0143.x[0])

            axes[0].plot(freq, 3*sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma: {:.2f} nm".format(1./fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')
            frc_text += "\n1/2 bit: {:.2f} nm".format(1./fsc_half_bit.x[0])

            axes[0].legend()

            x_ticklocs = axes[0].get_xticks()
            axes[0].set_xticklabels(["{:.1f}".format(1./i) for i in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center', verticalalignment='center', transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()
        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq': freq,
                   'corr': corr,
                   'smooth': corr_func(x=freq),
                   '1/7': np.ones_like(freq)/7,
                   '3 sigma': 3*sigma_spl(freq),
                   '1/2 bit': half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """Save the raw FRC curves and results to `save_path` (npz) if a path is set."""
        # compare by value - `is not ""` tests object identity and is unreliable for strings
        if self.save_path != "":
            np.savez_compressed(self.save_path,
                                raw=namespace[self.output_frc_raw],
                                results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Calculates binned statistics. Supports complex numbers.

        `bins` is the number of bin edges, spaced linearly between the min and max of `indexes`. Returns an object
        with `statistic` (per bin), `bin_edges` and `counts` (per bin) attributes. Raises ValueError for an
        unsupported `statistic`.
        """
        if statistic == 'mean':
            func = np.mean
        elif statistic == 'sum':
            func = np.sum
        else:
            raise ValueError("Unsupported statistic: %r" % (statistic,))

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        bins = np.linspace(indexes.min(), indexes.max(), bins)
        binned = np.zeros(len(bins)-1, dtype=data.dtype)
        counts = np.zeros(len(bins)-1, dtype=int)  # np.int was removed in numpy 1.24

        indexes_sort_arg = np.argsort(indexes.flatten())
        indexes_sorted = indexes.flatten()[indexes_sort_arg]
        data_sorted = data.flatten()[indexes_sort_arg]
        edges_indexes = np.searchsorted(indexes_sorted, bins)

        for i in range(bins.shape[0]-1):
            values = data_sorted[edges_indexes[i]:edges_indexes[i+1]]
            binned[i] = func(values)
            counts[i] = len(values)

        res = Result()
        res.statistic = binned
        res.bin_edges = bins
        res.counts = counts
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """Falling logistic curve: ~1 for x << c, 0.5 at x == c, ~0 for x >> c (steepness n)."""
        res = 1 - 1 / (1 + np.exp(n*(-x+c)))
        return res
class RGBImageUpload(ImageUpload):
    """
    Render an image as an RGB (png) picture and upload it to an OMERO server, optionally attaching localization
    files.

    Parameters
    ----------
    input_image : str
        name of the image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabular data are saved as '.hdf' files and
        attached to the image
    filePattern : str
        pattern used to build the name of the image on the OMERO server
    omero_dataset : str
        name of the OMERO dataset to add the image to. The dataset is created if it does not already exist.
    omero_project : str
        name of the OMERO project to link the dataset to. The project is created if it does not already exist.
    zoom : int
        how large to zoom the image (passed straight to the RGB/CMY renderer)
    scaling : str
        how to scale the intensity - one of 'min-max' or 'percentile'
    scaling_factor : float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta, and yellow rather than RGB when the image has more than one colour channel. True by
        default.

    Notes
    -----
    OMERO server address and user login information must be stored in the user PYME config directory under
    plugins/config/pyme-omero, e.g. /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        # CMY rendering only applies when requested AND the image has more than one channel (axis 3);
        # single-channel images always take the plain RGB path.
        use_cmy = self.colorblind_friendly and image.data.shape[3] != 1
        render = image_to_cmy if use_cmy else image_to_rgb

        pixels = render(image, zoom=self.zoom, scaling=self.scaling, scaling_factor=self.scaling_factor)
        Image.fromarray(pixels, mode='RGB').save(path)
class LabelRange(Filter):
    """Assign a unique integer label to each contiguous region of the input mask, keeping only regions whose size
    and number of sites both fall inside given ranges.

    A region (connected component of ``data > 0.5``) is retained when:

    * its area in pixels lies in [minRegionPixels, maxRegionPixels], and
    * the number of sites it contains, measured on the second input ``inputSitesLabeled``, lies in
      [minSites, maxSites].

    Sites are counted as the number of distinct positive labels of the sites image inside the region, or, if
    ``sitesAsMaxima`` is set, as the sum of the sites image over the region (presumably a 0/1 maxima image -
    confirm for your inputs).

    The retained regions are relabelled consecutively in the output.
    """
    inputSitesLabeled = Input("sites")  # sites and the main input must have the same shape!
    minRegionPixels = Int(10)
    maxRegionPixels = Int(100)
    minSites = Int(4)
    maxSites = Int(6)
    sitesAsMaxima = Bool(False)

    def filter(self, image, imagesites):
        """Apply ``applyFilter`` channel by channel (and optionally frame by frame) to the image/sites pair."""
        if self.processFramesIndividually:
            filt_ims = []
            for chanNum in range(image.data.shape[3]):
                filt_ims.append(
                    np.concatenate([
                        np.atleast_3d(
                            self.applyFilter(
                                image.data[:, :, i, chanNum].squeeze().astype('f'),
                                imagesites.data[:, :, i, chanNum].squeeze().astype('f'),
                                chanNum, i, image))
                        for i in range(image.data.shape[2])
                    ], 2))
        else:
            filt_ims = [
                np.atleast_3d(
                    self.applyFilter(
                        image.data[:, :, :, chanNum].squeeze().astype('f'),
                        imagesites.data[:, :, :, chanNum].squeeze().astype('f'),
                        chanNum, 0, image))
                for chanNum in range(image.data.shape[3])
            ]

        im = ImageStack(filt_ims, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(image.mdh)
        im.mdh['Parent'] = image.filename

        self.completeMetadata(im)

        return im

    def execute(self, namespace):
        namespace[self.outputName] = self.filter(
            namespace[self.inputName], namespace[self.inputSitesLabeled])

    def applyFilter(self, data, sites, chanNum, frNum, im):
        """Label ``data > 0.5`` and keep only regions passing the area and site-count criteria.

        Returns an integer label image of the retained regions (labels 1..N, 0 = background).
        """
        mask = data > 0.5
        labs, nlabs = ndimage.label(mask)

        rSize = self.minRegionPixels
        rMax = self.maxRegionPixels
        minSites = self.minSites
        maxSites = self.maxSites

        m2 = 0 * mask  # integer accumulator of retained regions

        objs = ndimage.find_objects(labs)
        for i, o in enumerate(objs):
            # pixels of region i + 1 within its bounding box (other regions may intrude into the box)
            r = labs[o] == i + 1
            area = r.sum()
            if (area >= rSize) and (area <= rMax):
                if self.sitesAsMaxima:
                    nsites = sites[o][r].sum()
                else:
                    # count the unique labels (excluding label 0 which is background)
                    nsites = (np.unique(sites[o][r]) > 0).sum()
                if (nsites >= minSites) and (nsites <= maxSites):
                    m2[o] += r

        # relabel the surviving regions consecutively
        labs, nlabs = ndimage.label(m2 > 0)

        return labs

    def completeMetadata(self, im):
        im.mdh['Labelling.MinSize'] = self.minRegionPixels
        im.mdh['Labelling.MaxSize'] = self.maxRegionPixels
        im.mdh['Labelling.MinSites'] = self.minSites
        im.mdh['Labelling.MaxSites'] = self.maxSites
        # sitesAsMaxima changes how MinSites/MaxSites are interpreted, so record it for provenance
        im.mdh['Labelling.SitesAsMaxima'] = self.sitesAsMaxima
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Take a pair of images and calculates the fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    input_image_a : ImageStack
        First of two images.

    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images. If blank, the first image is compared against itself
        (useful with different ``image_a_z`` / ``image_b_z`` planes).
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or performs a
        maximum projection (<0) for the first image.
    image_b_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or performs a
        maximum projection (<0) for the second image.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    """
    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)

    @staticmethod
    def _flatten_to_2d(data, z_index):
        """Reduce an (x, y[, z]) array to 2D: take plane ``z_index`` if >= 0, otherwise max-project along z.

        ``squeeze()`` removes a singleton z axis for single-plane images; ``np.atleast_3d`` restores it so the
        indexing / projection along axis 2 works for those images too.
        """
        data = np.atleast_3d(data)
        if z_index >= 0:
            return data[:, :, z_index]
        return data.max(2)

    def execute(self, namespace):
        self._namespace = namespace
        import multiprocessing

        if self.multiprocessing:
            # at most 2 workers but never fewer than 1 (np.clip(2, 1, cpu_count() - 1) yielded 0 on a
            # single-core machine, which multiprocessing.Pool rejects)
            proccess_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=proccess_count)

        try:
            image_a = namespace[self.input_image_a]
            if len(self.image_b_path.strip()) == 0:
                image_b = image_a
            else:
                image_b = ImageStack(filename=self.image_b_path)

            self._pixel_size_in_nm = np.zeros(3, dtype=float)
            self._pixel_size_in_nm[0] = image_a.mdh.voxelsize.x
            self._pixel_size_in_nm[1] = image_a.mdh.voxelsize.y
            try:
                self._pixel_size_in_nm[2] = image_a.mdh.voxelsize.z
            except Exception:
                # z voxel size is optional; leave it at 0 if unavailable
                pass
            if image_a.mdh.voxelsize.units == 'um':
                self._pixel_size_in_nm *= 1.E3

            image_a_data = image_a.data[:, :, :, self.c_channel].squeeze()
            image_b_data = image_b.data[:, :, :, self.c_channel].squeeze()

            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
                image_a_data = self._flatten_to_2d(image_a_data, self.image_a_z)
                image_b_data = self._flatten_to_2d(image_b_data, self.image_b_z)

            image_pair = self.preprocess_images([image_a_data, image_b_data])

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release the worker processes even if the FRC computation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.

    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Runtime will vary hugely depending on size of dataset and settings.

    ``cache_fft`` is necessary for large datasets.

    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.* Dataset to correct.

    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results.
    output_drift_plot : Plot
        *Deprecated.* Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.* Drift-corrected dataset.

    Parameters
    ----------
    step : Int
        Setting for image construction. Step size between images
    window : Int
        Setting for image construction. Number of frames used per image. Should be equal or larger than step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Float
        Size of correlation window. Frames are only compared if within this frame range. N/A for DCC.
    multiprocessing : Bool
        Enables multiprocessing.
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """
    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # redundant cross-corelation, mean cross-correlation, direction cross-correlation
    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')

    def calc_corr_drift_from_locs(self, x, y, z, t):
        """
        Histogram localizations in successive time windows, Fourier transform each histogram and cross-correlate.

        Parameters
        ----------
        x, y, z : array-like
            Localization coordinates, in the same units as ``binsize``.
        t : array-like
            Time (frame) value of each localization. Need not be sorted.

        Returns
        -------
        time_values_mid : ndarray
            Mid-point of each time window.
        shifts : ndarray
            Shifts from ``calc_corr_drift_from_ft_images`` multiplied by ``binsize``, columns in x, y, z order.
        coefs
            Correlation coefficients from ``calc_corr_drift_from_ft_images``.
        """
        # bin edges for histogram
        bx = np.arange(x.min(), x.max() + self.binsize + 1, self.binsize)
        by = np.arange(y.min(), y.max() + self.binsize + 1, self.binsize)
        bz = np.arange(z.min(), z.max() + self.binsize + 1, self.binsize)

        # pad bin length to odd number so image size is even
        if bx.shape[0] % 2 == 0:
            bx = np.concatenate([bx, [bx[-1] + bx[1] - bx[0]]])
        if by.shape[0] % 2 == 0:
            by = np.concatenate([by, [by[-1] + by[1] - by[0]]])
        if bz.shape[0] > 2 and bz.shape[0] % 2 == 0:
            bz = np.concatenate([bz, [bz[-1] + bz[1] - bz[0]]])
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Ops. Image not correctly padded to even size."

        # start time of all windows, allow partial window near end of pipeline
        time_values = np.arange(t.min(), t.max() + 1, self.step)
        # 2d array, start and end time of windows
        time_values = np.stack([time_values, np.clip(time_values + self.window, None, t.max())], axis=1)
        n_steps = time_values.shape[0]
        # centre time of each window, for returning. last window may have different spacing
        time_values_mid = time_values.mean(axis=1)

        if np.any(np.diff(t) < 0):
            # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]

        # [start, end) localization index range of each window
        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1] - 1, side='right')

        # Crude way of swapping the longest axis to the last, for rfft performance.
        # Code changed for this is limited to this method.
        xyz = np.asarray([x, y, z])
        # The bin-edge arrays usually differ in length, so keep them in a plain list: np.asarray on a ragged
        # list raises ValueError on NumPy >= 1.24.
        bxyz = [bx, by, bz]
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = [bxyz[i] for i in dims_order]
        dims_length = dims_length[dims_order]

        # rfft output shape: full length on the first two axes, half (+1) on the last
        ft_shape = (n_steps, dims_length[0] - 1, dims_length[1] - 1, (dims_length[2] - 1) // 2 + 1)

        # use memmap for caching if ft_cache is defined
        if self.cache_fft == "":
            ft_images = np.zeros(ft_shape, dtype=complex)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=complex, mode='w+', shape=ft_shape)

        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))

        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))

        # report progress about 5 times; n_steps // 5 is 0 when there are fewer than 5 windows, which made the
        # modulo below raise ZeroDivisionError
        report_every = max(1, n_steps // 5)

        # fill ft_images
        # if multiprocessing, can either use or not caching
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:, slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size)
                    for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                if self.cache_fft == "":
                    # without a disk cache the worker returns the transform instead of writing it itself
                    ft_images[j] = res

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i + 1, n_steps))

        else:
            # For each window we wish to correlate...
            for i, ti in enumerate(time_indexes):
                # .. we generate an image and store ft of image
                t_slice = slice(*ti)
                ft_images[i] = calc_fft_from_locs(xyz[:, t_slice].T, bxyz, filter_size=self.tukey_size)

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i + 1, n_steps))

        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))

        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)

        # clean up of ft_images, potentially really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images

        # dims_order is a single swap, so indexing with it again restores x, y, z column order
        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")

        if self.multiprocessing:
            proccess_count = np.clip(multiprocessing.cpu_count() - 1, 1, None)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=proccess_count)})

        try:
            locs = namespace[self.input_for_correction]

            drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'], locs['z'] * (0 if self.flatten_z else 1), locs['t'])
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # always release the worker processes, even if the correlation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        # getting around oddity with mappingFilter:
        # addColumn adds a new column but also keeps the old column; __getitem__ returns the new column
        # but mappings use the old column. If dx already existed, wrap with another level of mappingFilter
        # so the new column becomes the 'old column'.
        had_drift_columns = 'dx' in out.keys()
        out.addColumn('dx', dx)
        out.addColumn('dy', dy)
        out.addColumn('dz', dz)
        if had_drift_columns:
            out = tabular.mappingFilter(out)

        out.setMapping('x', 'x + dx')
        out.setMapping('y', 'y + dy')
        out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))

        namespace[self.output_cross_cor] = self._cc_image
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    display_normals = Bool(False)
    normal_scaling = Float(10.0)
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode, display_normals, normal_scaling')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            try:
                self._pipeline.onRebuild.connect(self.update)
            except AttributeError:
                pass

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).

        Tries ``pipeline.get_layer_data(dsname)`` first, then ``pipeline[dsname]``. Returns None if the pipeline is
        None or (in the item-lookup fallback) has no datasource of that name.
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            try:
                return self._pipeline[self.dsname]
            except (AttributeError, KeyError, TypeError):
                # TypeError: pipeline is None / not subscriptable (previously escaped and crashed __init__);
                # KeyError: no datasource with that name
                return None

    @property
    def _ds_class(self):
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the values used for colouring, or [0, 1] if they are unavailable."""
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        # colour variable changed: reset the colour limits to the data range before redrawing
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]

        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        except AttributeError:
            pass

        if self.datasource is not None:
            dks = ['constant', ]
            if hasattr(self.datasource, 'keys'):
                dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Extract per-corner vertex positions, normals and colours from a triangle mesh and cache them for the engine.

        Parameters
        ----------
        ds :
            PYME.experimental.triangular_mesh.TriangularMesh object

        Returns
        -------
        None
        """
        # one entry per face corner
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            # repeat each face normal for the face's 3 corners
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if xn is not None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        # TODO: This temporarily sets all triangles to the color red. User should be able to select color.
        if c is None:
            c = np.ones(self._vertices.shape[0]) * 255  # vector of pink

        if clim is not None and c is not None and cmap is not None:
            crange = clim[1] - clim[0]
            if crange == 0:
                # degenerate limits (e.g. constant data) would divide by zero and give NaN colours
                crange = 1
            cs_ = (c - clim[0]) / crange
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if self._vertices is not None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices'), visible_when='_datasource_choices')]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel", "shaded"]')])
                     ], )

    def default_traits_view(self):
        return self.default_view
class CalculateFRCFromLocs(CalculateFRCBase):
    """
    Generates a pair of images from localization data and calculates the fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    inputName : TabularBase
        Localization data.

    Outputs
    -------
    outputName : TabularBase
        Localization data labeled with how it was divided (FRC_group).
    output_images : ImageStack
        Pair of 2D or 3D histogram rendered images.
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    split_method : string
        Different methods of dividing data into halves.
    pixel_size_in_nm : float
        Pixel size used for rendering the images.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    """
    inputName = Input('Localizations')
    split_method = Enum(['halves_random', 'halves_time', 'halves_100_time_chunk', 'halves_10_time_chunk', 'fixed_time', 'fixed_10_time_chunk'])
    pixel_size_in_nm = Float(5)
    flatten_z = Bool(True)

    outputName = Output('FRC_ds')
    output_images = Output('FRC_images')

    def execute(self, namespace):
        self._namespace = namespace
        import multiprocessing

        if self.multiprocessing:
            # at most 2 workers but never fewer than 1 (np.clip(2, 1, cpu_count() - 1) yielded 0 on a
            # single-core machine, which multiprocessing.Pool rejects)
            proccess_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=proccess_count)

        try:
            pipeline = namespace[self.inputName]
            mapped_pipeline = tabular.mappingFilter(pipeline)
            self._pixel_size_in_nm = self.pixel_size_in_nm * np.ones(3, dtype=float)

            image_pair = self.generate_image_pair(mapped_pipeline)
            image_pair = self.preprocess_images(image_pair)

            # Should use DensityMapping recipe eventually when it is ready.
            mdh = MetaDataHandler.NestedClassMDHandler()
            mdh['Rendering.Method'] = "np.histogramdd"
            if 'imageID' in pipeline.mdh.getEntryNames():
                mdh['Rendering.SourceImageID'] = pipeline.mdh['imageID']
            try:
                mdh['Rendering.SourceFilename'] = pipeline.resultsSource.h5f.filename
            except Exception:
                # best effort: the source filename is only reachable for some pipeline types
                pass

            mdh.Source = MetaDataHandler.NestedClassMDHandler(pipeline.mdh)

            # NOTE(review): these sums are taken after preprocess_images; if that applies a window they are not raw
            # event counts - confirm
            mdh['Rendering.NEventsRendered'] = [image_pair[0].sum(), image_pair[1].sum()]
            mdh['voxelsize.units'] = 'um'
            mdh['voxelsize.x'] = self.pixel_size_in_nm * 1E-3
            mdh['voxelsize.y'] = self.pixel_size_in_nm * 1E-3
            if not self.flatten_z:
                # generate_image_pair bins z with the same pixel size
                mdh['voxelsize.z'] = self.pixel_size_in_nm * 1E-3

            ims = ImageStack(data=np.stack(image_pair, axis=-1), mdh=mdh)
            namespace[self.output_images] = ims

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, pipeline.mdh)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release the worker processes even if rendering / FRC fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)

    def generate_image_pair(self, mapped_pipeline):
        """
        Split the localizations into two groups according to ``split_method`` and histogram each group.

        Adds an ``FRC_group`` column (boolean mask of the second group) to ``mapped_pipeline`` and stores it in the
        namespace under ``outputName``.

        Returns
        -------
        image_a, image_b : ndarray
            Histograms of the ``FRC_group == False`` and ``FRC_group == True`` localizations respectively.
        """
        t = mapped_pipeline['t']

        # Split localizations into 2 sets
        mask = np.zeros(t.shape, dtype=bool)
        if self.split_method == 'halves_time':
            sort_arg = np.argsort(t)
            mask[sort_arg[:len(sort_arg)//2]] = 1
        elif self.split_method == 'halves_random':
            mask[:len(mask)//2] = 1
            np.random.shuffle(mask)
        elif self.split_method == 'halves_100_time_chunk':
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] / 100.0
            for i in range(50):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'halves_10_time_chunk':
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] * 0.1
            for i in range(5):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'fixed_time':
            # np.ptp rather than ndarray.ptp, which was removed in NumPy 2.0
            time_cutoff = (np.ptp(t) + 1) // 2 + t.min()
            mask[t < time_cutoff] = 1
        elif self.split_method == 'fixed_10_time_chunk':
            time_cutoffs = np.linspace(t.min(), t.max(), 11, dtype=float)
            for i in range((time_cutoffs.shape[0]-1)//2):
                # half-open chunks [start, end): a strict '>' dropped localizations at each chunk start
                # (including the very first frame) into the other group
                mask[(t >= time_cutoffs[i*2]) & (t < time_cutoffs[i*2+1])] = 1

        if self.split_method.startswith('halves_'):
            assert np.abs(mask.sum() - (~mask).sum()) <= 1, "datasets uneven, {} vs {}".format(mask.sum(), (~mask).sum())
        else:
            print("variable counts between images, {} vs {}".format(mask.sum(), (~mask).sum()))

        mapped_pipeline.addColumn("FRC_group", mask)
        self._namespace[self.outputName] = mapped_pipeline

        dims = ['x', 'y', 'z']
        if self.flatten_z:
            dims.remove('z')

        # Simple histogram binning with pixel_size_in_nm bins along every used dimension
        bins = list()
        data = list()
        for dim in dims:
            values = mapped_pipeline[dim]
            bins.append(np.arange(np.floor(values.min()).astype(int), np.ceil(values.max()).astype(int) + self.pixel_size_in_nm, self.pixel_size_in_nm))
            data.append(values)

        data = np.stack(data, axis=1)

        if self.multiprocessing:
            results = list()
            results.append(self._pool.apply_async(np.histogramdd, (data[mask == 0],), kwds={"bins": bins}))
            results.append(self._pool.apply_async(np.histogramdd, (data[mask == 1],), kwds={"bins": bins}))
            image_a = results[0].get()[0]
            image_b = results[1].get()[0]
        else:
            image_a = np.histogramdd(data[mask == 0], bins=bins)[0]
            image_b = np.histogramdd(data[mask == 1], bins=bins)[0]

        return image_a, image_b
class EnsembleFitROIs(ModuleBase):  # Note that this should probably be moved somewhere else
    """Fit the ROIs in ``inputName`` with a region fitter chosen by ``fit_type``.

    If ``hold_ensemble_parameter_constant`` is set, the fitter's ``fit_profiles`` is called with
    ``ensemble_parameter_guess``; otherwise ``ensemble_fit`` is called with it as the starting guess. The fitter's
    results are output as a tabular source, carrying the input metadata (if any) plus the fit settings.
    """
    inputName = Input('ROIs')
    fit_type = CStr('LorentzianConvolvedSolidSphere_ensemblePSF')
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)
    outputName = Output('fit_results')

    def execute(self, namespace):
        rois = namespace[self.inputName]

        # rebuild a RegionHandler from the input tables
        region_handler = RegionHandler()
        region_handler._load_from_list(rois)

        fitter = region_fitters.fitters[self.fit_type](region_handler)
        run_fit = fitter.fit_profiles if self.hold_ensemble_parameter_constant else fitter.ensemble_fit
        run_fit(self.ensemble_parameter_guess)

        results = tabular.RecArraySource(fitter.results)

        # carry over the input metadata (if present), then record the fit settings
        results.mdh = MetaDataHandler.NestedClassMDHandler(getattr(rois, 'mdh', None))
        fit_settings = (
            ('EnsembleFitROIs.FitType', self.fit_type),
            ('EnsembleFitROIs.EnsembleParameterGuess', self.ensemble_parameter_guess),
            ('EnsembleFitROIs.HoldEnsembleParamConstant', self.hold_ensemble_parameter_constant),
        )
        for key, value in fit_settings:
            results.mdh[key] = value

        namespace[self.outputName] = results

    @property
    def _fitter_choices(self):
        # FIXME??? NOTE(review): choices come from region_fitters.ensemble_fitters, but execute() looks the
        # selection up in region_fitters.fitters - confirm the former is a subset of the latter
        return list(region_fitters.ensemble_fitters.keys())

    def _fit_settings_items(self):
        """Items common to every view: fit type, ensemble guess and hold-constant flag, separated by rules."""
        from traitsui.api import Item
        from PYME.ui.custom_traits_editors import CBEditor

        return [Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                Item('_'),
                Item('ensemble_parameter_guess'),
                Item('_'),
                Item('hold_ensemble_parameter_constant')]

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        items = [Item('inputName', editor=CBEditor(choices=self._namespace_keys)), Item('_')]
        items.extend(self._fit_settings_items())
        items.extend([Item('_'), Item('outputName')])
        return View(*items, buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View

        return View(*self._fit_settings_items())

    @property
    def dsview_view(self):
        from traitsui.api import View

        return View(*self._fit_settings_items(), buttons=['OK'])
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, using one of the rendering engines registered in ``ENGINES``.

    Point coordinates (and optional normals) are read from the pipeline datasource named by ``dsname``.
    They are coloured by the column named in ``vertexColour`` (scaled by ``clim`` and mapped through
    ``cmap``) and stored via ``update_data`` / ``set_values`` for the engine to render.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    display_normals = Bool(False)
    normal_scaling = Float(10.0)
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='points', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        self.method = method

        # explicit call: assigning `method` only notifies if the value actually changed
        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """(Re)create the rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the colour column, falling back to ``[0, 1]`` when there is no datasource or no such column."""
        ds = self.datasource
        if ds is None:
            # previously ds[...] raised TypeError here, which the KeyError handler did not catch
            return np.array([0, 1])

        try:
            return ds[self.vertexColour]
        except KeyError:
            return np.array([0, 1])

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the (NaN-ignoring) range of the colour column; the 1e-9 keeps the range non-zero."""
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata)) + 1e-9]

    def update(self, *args, **kwargs):
        """Refresh the datasource choice lists and, if possible, reload data and notify listeners."""
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names

        ds = self.datasource  # evaluate once; each access queries the pipeline
        if ds is not None:
            self._datasource_keys = sorted(ds.keys())

        if not (self.engine is None or ds is None):
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Pull coordinates, colour values and (optional) normals out of ``ds`` and pass them to ``update_data``.

        A missing z column gives z = 0. An empty ``vertexColour`` gives colour values of 0. Normals are used
        only when the ``xn`` column is present.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        z = 0 * x
        if self.z_key is not None:
            try:
                z = ds[self.z_key]
            except KeyError:
                pass

        c = ds[self.vertexColour] if self.vertexColour != '' else 0 * x

        normal_kwargs = {}
        if self.xn_key in ds.keys():
            normal_kwargs = dict(xn=ds[self.xn_key], yn=ds[self.yn_key], zn=ds[self.zn_key])

        self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha,
                         **normal_kwargs)

    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0, xn=None, yn=None,
                    zn=None):
        """
        Build vertex, normal and RGBA colour arrays from coordinate/colour columns and store them.

        Parameters
        ----------
        x, y, z : arrays of point coordinates (flattened). If any is None or x is empty, no vertices are stored.
        colors : per-point values, scaled to [0, 1] with ``clim`` and mapped through ``cmap``.
        cmap : callable colour map returning RGBA rows.
        clim : (lower, upper) scaling limits.
        alpha : value written into the alpha channel of every colour.
        xn, yn, zn : optional per-point normals. Without them a constant -0.69 normal is used.

        If ``cmap``, ``clim`` or ``colors`` is missing, every point is coloured white (all ones).
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None and len(x) > 0:
            n_points = len(x.ravel())
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(n_points, 3)

            if xn is not None:
                normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(n_points, 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # the original condition tested `clim` twice and never `cmap`, so cmap=None crashed on cmap(cs_)
        if cmap is not None and clim is not None and colors is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        elif vertices is not None:
            cs = np.ones((vertices.shape[0], 4), 'f')
        else:
            cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self, vertices=None, normals=None, colors=None, color_map=None, color_limit=None, alpha=None):
        """Store each supplied (non-None) value; None arguments leave the current value untouched."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour',
                          visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                                 show_label=False), ], visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group(Item('cmap', label='LUT'),
                           Item('alpha', visible_when="method in ['pointsprites', 'transparent_points']",
                                editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                           Item('point_size', label=u'Point\u00A0size',
                                editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)))])
        #buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class LUTOverlayLayer(OverlayLayer):
    """
    This OverlayLayer produces a bar that indicates the given color map.

    One vertical colour bar is drawn for every layer in ``gl_canvas.layers`` that is visible and does not set
    ``show_lut`` to False. Bars are placed leftwards from the right-hand edge of the canvas, each 1.5 bar
    widths apart. If ``show_bounds`` is True and a layer has a ``clim`` attribute, its lower and upper limits
    are drawn as text labels at the bottom and top of that layer's bar.
    """
    show_bounds = Bool(False)  # draw '%.3G' clim labels at the ends of each bar

    def __init__(self, offset=None, **kwargs):
        """

        Parameters
        ----------

        offset
            offset of the canvas origin where it should be drawn.
            Currently only offset[0] is used: the distance of the first bar's right edge from the right edge
            of the canvas (same units as ``gl_canvas.Size``). A falsy value is replaced by [10, 10].
        """
        if not offset:
            offset = [10, 10]

        OverlayLayer.__init__(self, offset, **kwargs)

        self.set_offset(offset)

        self._lut_width_px = 10.0  # width of each colour bar
        self._border_colour = [.5, .5, 0]  # RGB of the outline drawn around each bar
        self.set_shader_program(DefaultShaderProgram)

        self._labels = {}  # layer -> [lower_label, upper_label] text objects, created lazily

    def _get_label(self, layer):
        """Return the [lower, upper] text labels for ``layer``, creating and caching them on first use."""
        from . import text
        try:
            return self._labels[layer]
        except KeyError:
            self._labels[layer] = [text.Text(), text.Text()]

            return self._labels[layer]

    def render(self, gl_canvas):
        """Draw one colour bar (plus optional clim labels) for each visible layer that wants a LUT."""
        if not self.visible:
            return

        self._clear_shader_clipping()
        labels = []  # text labels are collected here and rendered after the shader program is released

        with self.shader_program:
            glDisable(GL_DEPTH_TEST)
            glDisable(GL_LIGHTING)
            glDisable(GL_BLEND)
            #view_size_x = gl_canvas.xmax - gl_canvas.xmin
            #view_size_y = gl_canvas.ymax - gl_canvas.ymin

            view_size_x, view_size_y = gl_canvas.Size

            # The bar spans 10%..90% of the canvas height.
            # NOTE(review): the code appears to assume y grows downwards (the upper label is placed at
            # lb_ur_y - tu._h), so lb_ur_y is the top edge and lb_lr_y the bottom edge - confirm.
            # upper right y
            lb_ur_y = .1 * view_size_y
            # lower right y
            lb_lr_y = 0.9 * view_size_y
            lb_len = lb_lr_y - lb_ur_y

            lb_width = self._lut_width_px #* view_size_x / gl_canvas.Size[0]

            visible_layers = [l for l in gl_canvas.layers if (getattr(l, 'visible', True) and getattr(l, 'show_lut', True))]

            for j, l in enumerate(visible_layers):
                cmap = l.colour_map

                # upper right x: the j-th bar is shifted left by j * 1.5 bar widths
                lb_ur_x = view_size_x - self.get_offset()[0] - j * 1.5 * lb_width

                # upper left x
                lb_ul_x = lb_ur_x - lb_width

                #print(lb_ur_x, lb_ur_y, lb_lr_y, lb_ul_x, lb_width, lb_len, view_size_x, view_size_y)

                # Gradient: sample the colour map at ~101 points in [0, 1]; value 0 lands at lb_lr_y and
                # value 1 at lb_ur_y.
                glBegin(GL_QUAD_STRIP)

                for i in numpy.arange(0, 1.01, .01):
                    glColor3fv(cmap(i)[:3])
                    glVertex2f(lb_ul_x, lb_ur_y + (1. - i) * lb_len)
                    glVertex2f(lb_ur_x, lb_ur_y + (1. - i) * lb_len)

                glEnd()

                # outline around the bar
                glBegin(GL_LINE_LOOP)
                glColor3fv(self._border_colour)
                glVertex2f(lb_ul_x, lb_lr_y)
                glVertex2f(lb_ur_x, lb_lr_y)
                glVertex2f(lb_ur_x, lb_ur_y)
                glVertex2f(lb_ul_x, lb_ur_y)
                glEnd()

                if hasattr(l, 'clim') and self.show_bounds:
                    tl, tu = self._get_label(l)
                    cl, cu = l.clim
                    tu.text = '%.3G' % cu
                    tl.text = '%.3G' % cl

                    # centre both labels horizontally on the bar: upper label just beyond lb_ur_y,
                    # lower label starting at lb_lr_y
                    xc = lb_ur_x - 0.5 * lb_width

                    tu.pos = (xc - tu._w / 2, lb_ur_y - tu._h)
                    tl.pos = (xc - tl._w / 2, lb_lr_y)

                    labels.extend([tl, tu])

        # presumably rendered outside the `with` because text objects use their own shader - confirm
        for l in labels:
            l.render(gl_canvas)
class FiducialTrack(ModuleBase):
    """
    Extract average fiducial track from input pipeline

    Parameters
    ----------

        radiusMultiplier: this number is multiplied with error_x to obtain search radius for clustering
        timeWindow: the window along the time dimension used for clustering
        filterScale: the size of the filter kernel used to smooth the resulting average fiducial track
        filterMethod: enumerated choice of filter methods for smoothing operation (Gaussian, Median or Uniform kernel)
        clumpMinSize: passed through to ``extractTrajectoriesClump`` as ``clumpMinSize``
        singleFiducial: if True, all data are assumed to come from a single fiducial and track fragments are
            not aligned to each other (avoids offsets caused by incomplete tracks)

    Notes
    -----

    Output is a new pipeline (a mapping filter over the input) with one added ``fiducial_<dim>`` column per
    dimension returned by ``AverageTrack`` (e.g. fiducial_x, fiducial_y), plus an ``isFiducial`` column.
    Input metadata is propagated when present.
    """
    import PYMEcs.Analysis.trackFiducials as tfs
    inputName = Input('filtered')

    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    # list(): on Python 3 .keys() is a dict_keys view, which traits' Enum does not treat as a sequence of choices
    filterMethod = Enum(list(tfs.FILTER_FUNCS.keys()))
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        import PYMEcs.Analysis.trackFiducials as tfs

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # if all data is from a single fiducial we do not need to align
        # we then avoid problems with incomplete tracks giving rise to offsets between
        # fiducial track fragments
        align = not self.singleFiducial

        t, x, y, z, isFiducial = tfs.extractTrajectoriesClump(inp, clumpRadiusVar='error_x',
                                                              clumpRadiusMultiplier=self.radiusMultiplier,
                                                              timeWindow=self.timeWindow,
                                                              clumpMinSize=self.clumpMinSize,
                                                              align=align)
        rawtracks = (t, x, y, z)
        tracks = tfs.AverageTrack(inp, rawtracks, filter=self.filterMethod,
                                  filterScale=self.filterScale, align=align)

        # add tracks for all calculated dims to output
        for dim in tracks.keys():
            mapped.addColumn('fiducial_%s' % dim, tracks[dim])
        mapped.addColumn('isFiducial', isFiducial)

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        return ['columns']
class EnsembleFitProfiles(ModuleBase):
    """
    Fit a table of line profiles with one of the ensemble profile fitters.

    The fitter class is looked up in ``profile_fitters.ensemble_fitters`` by ``fit_type`` and kept on the
    instance as ``self.fitter``. Two fit modes are available:

    * If ``hold_ensemble_parameter_constant`` is True, ``fit_profiles`` is called with
      ``ensemble_parameter_guess``.
    * Otherwise ``ensemble_fit`` is called with that guess.

    The fitter's results are published as a ``tabular.RecArraySource``, with the input metadata and the fit
    settings attached.
    """
    inputName = Input('line_profiles')

    fit_type = CStr(list(profile_fitters.ensemble_fitters.keys())[0])
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        profiles = namespace[self.inputName]

        # rebuild a LineProfileHandler from the tabular profile description
        profile_handler = LineProfileHandler()
        profile_handler._load_profiles_from_list(profiles)

        self.fitter = profile_fitters.ensemble_fitters[self.fit_type](profile_handler)

        if not self.hold_ensemble_parameter_constant:
            self.fitter.ensemble_fit(self.ensemble_parameter_guess)
        else:
            self.fitter.fit_profiles(self.ensemble_parameter_guess)

        results = tabular.RecArraySource(self.fitter.results)

        # propagate metadata, if present (getattr yields None when the input has no mdh)
        results.mdh = MetaDataHandler.NestedClassMDHandler(getattr(profiles, 'mdh', None))

        for md_key, md_value in (('EnsembleFitProfiles.FitType', self.fit_type),
                                 ('EnsembleFitProfiles.EnsembleParameterGuess', self.ensemble_parameter_guess),
                                 ('EnsembleFitProfiles.HoldEnsembleParamConstant',
                                  self.hold_ensemble_parameter_constant)):
            results.mdh[md_key] = md_value

        namespace[self.outputName] = results

    @property
    def _fitter_choices(self):
        return list(profile_fitters.ensemble_fitters.keys())

    def _fit_setting_view_items(self):
        """Return the traitsui Items shared by every view: fit type, parameter guess and the hold flag."""
        from traitsui.api import Item
        from PYME.ui.custom_traits_editors import CBEditor

        return [Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                Item('_'),
                Item('ensemble_parameter_guess'),
                Item('_'),
                Item('hold_ensemble_parameter_constant')]

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        header = [Item('inputName', editor=CBEditor(choices=self._namespace_keys)), Item('_')]
        footer = [Item('_'), Item('outputName')]
        return View(*(header + self._fit_setting_view_items() + footer), buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View

        return View(*self._fit_setting_view_items())

    @property
    def dsview_view(self):
        from traitsui.api import View

        return View(*self._fit_setting_view_items(), buttons=['OK'])
class AlignPSF(ModuleBase):
    """
        Align PSF stacks by redundant cross correlation.

        Processing steps:

        1. Each PSF stack (channel axis 3 of the input) is cropped in z to the central
           ``2 * z_crop_half_roi + 1`` planes and normalised.
        2. The stacks are optionally apodised with a Tukey window.
        3. Every pair of stacks is cross-correlated. The correlation peak is located by a Gaussian fit or an
           RBF interpolant.
        4. The per-stack shifts are solved by least squares with RCC outlier rejection.

        The un-cropped input (or the cleaned stack when ``debug`` is set) is shifted by those drifts. The
        cross-correlation images and their fitted models are also written to the namespace.
    """

    inputName = Input('psf_cropped')
    normalize_z = Bool(True)
    tukey = Float(0.50)
    rcc_tolerance = Float(5.0)
    z_crop_half_roi = Int(15)
    peak_detect = Enum(['Gaussian', 'RBF'])
    debug = Bool(False)
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C'
        psf_stack = ims.data[:,:,:,:]

        # central z planes only; clamp the start so a short stack does not give a negative (wrapping) slice start
        z_centre = psf_stack.shape[2]//2
        z_slice = slice(max(z_centre - self.z_crop_half_roi, 0), z_centre + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:,:,z_slice,:])

        if self.tukey > 0:
            # window functions live in scipy.signal.windows (scipy.signal.tukey was removed in SciPy 1.13);
            # fall back to the old location for very old SciPy versions
            tukey_window = getattr(signal, 'windows', signal).tukey
            masks = [tukey_window(dim_len, self.tukey) for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:,:,:,None]

        # NOTE(review): tolerance converted with 1E3 / voxelsize.x - confirm units of rcc_tolerance and voxelsize
        drifts = self.calculate_shifts(cleaned_psf_stack, self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        namespace[self.output_images] = ImageStack(self.shift_images(cleaned_psf_stack if self.debug else psf_stack, drifts), mdh=ims.mdh)

    def normalize_images(self, psf_stack):
        """
        Return a background-subtracted, normalised copy of ``psf_stack``.

        Steps:

        1. Clip negative values to 0.
        2. Subtract each stack's minimum.
        3. Scale so that each z plane (``normalize_z``) or each whole stack peaks at 1.05.
        4. Subtract 0.05 and clip at 0.

        The result peaks at 1.0, and the lowest ~5% of the range is zeroed.
        """
        # in case it is already bg subtracted
        cleaned_psf_stack = np.clip(psf_stack, 0, None)

        # subtract bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0,1,2), keepdims=True)

        if self.normalize_z:
            # normalize intensity per plane
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0,1), keepdims=True) / 1.05
        else:
            # normalize intensity per psf stack
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0,1,2), keepdims=True) / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, out=cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """
        Estimate the (x, y, z) offset of every PSF stack by redundant cross-correlation.

        Parameters
        ----------
        psf_stack : 4D array (x, y, z, stack)
        drift_tolerance : residual above which a pairwise measurement is dropped (if the system stays full rank)

        Returns
        -------
        drifts : (n_stacks, 3) array of cumulative offsets, re-centred on the centre of mass of the bright core
            of the mean PSF.
        """
        n_steps = psf_stack.shape[3]
        # one row per pair (i, j), i < j; integer division keeps this a valid array dimension on Python 3
        coefs_size = n_steps * (n_steps-1) // 2
        coefs = np.zeros((coefs_size, n_steps-1))
        shifts = np.zeros((coefs_size, 3))

        cc_shape = tuple(psf_stack.shape[:3]) + (coefs_size,)
        output_cross_corr_images = np.zeros(cc_shape, dtype=float)
        output_cross_corr_images_fitted = np.zeros(cc_shape, dtype=float)

        counter = 0
        for i in np.arange(0, n_steps - 1):
            for j in np.arange(i+1, n_steps):
                # pair (i, j) spans the stack-to-stack steps i .. j-1
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))
                correlate_result = signal.correlate(psf_stack[:,:,:,i], psf_stack[:,:,:,j], mode="same")
                correlate_result -= correlate_result.min()
                correlate_result /= correlate_result.max()

                # keep only the main lobe (>= 0.5 after normalisation); everything else becomes NaN
                threshold = 0.50
                correlate_result[correlate_result<threshold] = np.nan

                self._keep_brightest_region(correlate_result)

                output_cross_corr_images[:,:,:,counter] = np.nan_to_num(correlate_result)

                # pixel coordinates along each axis, centred on zero
                dims = []
                for dim in correlate_result.shape:
                    axis_coords = np.arange(dim)
                    dims.append(axis_coords - axis_coords.mean())

                pair_shift, fitted_image = self._fit_peak(correlate_result, dims)
                output_cross_corr_images_fitted[:,:,:,counter] = fitted_image
                shifts[counter, :] = pair_shift

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals = np.matmul(coefs, drifts) - shifts
        residuals_dist = np.linalg.norm(residuals, axis=1)

        shift_max = drift_tolerance
        # Sort and mask residual errors
        residuals_arg = np.argsort(-residuals_dist)
        residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

        # Remove coefs rows
        # Descending from largest residuals to small
        # Only if matrix remains full rank
        coefs_temp = np.empty_like(coefs)
        counter = 0
        for index in residuals_arg:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                counter += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(counter))
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        # step-to-step drifts -> offsets relative to the first stack (which gets 0)
        drifts = np.pad(drifts, [[1,0],[0,0]], 'constant', constant_values=0)
        np.cumsum(drifts, axis=0, out=drifts)

        # re-centre on the centre of mass of the bright (> 50% of max) core of the mean normalised PSF
        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0,1,2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        center_offset = ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape)*0.5

        drifts = drifts - center_offset

        self._plot_drifts(drifts)

        return drifts

    @staticmethod
    def _keep_brightest_region(correlate_result):
        """In place: set to NaN every connected non-NaN region except the one holding the highest value."""
        labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
        # protects against > 1 peak in the cross correlation results
        # shouldn't happen anyway, but at least avoid fitting a single to multi-modal data
        if labeled_counts > 1:
            region_maxima = ndimage.maximum(correlate_result, labeled_image, np.arange(labeled_counts)+1)
            # labels are 1-based; the previous argsort(...)[0] picked the region with the *smallest* maximum
            brightest_label = np.argmax(region_maxima) + 1
            correlate_result[labeled_image != brightest_label] = np.nan

    def _fit_peak(self, correlate_result, dims):
        """
        Locate the cross-correlation peak with the configured ``peak_detect`` method.

        Returns ``(shift, fitted_image)``: the (x, y, z) peak position and the fitted model evaluated on ``dims``.
        Alternative models (squared Gaussian, Lorentzian) were tried previously and are not used.
        """
        if self.peak_detect == "Gaussian":
            # parameter elements 2, 4 and 6 are read as the x, y and z peak positions
            res = optimize.least_squares(guassian_nd_error, [1, 0, 0, 5., 0, 5., 0, 30.],
                                         args=(dims, correlate_result))
            fitted_image = gaussian_nd(res.x, dims)
            return res.x[[2, 4, 6]], fitted_image
        elif self.peak_detect == "RBF":
            rbf_interpolator = build_rbf(dims, correlate_result)
            # NOTE(review): dims are centred on zero but this starting point is at half the array size - confirm
            res = optimize.minimize(rbf_nd_error,
                                    [correlate_result.shape[0]*0.5, correlate_result.shape[1]*0.5,
                                     correlate_result.shape[2]*0.5],
                                    args=rbf_interpolator)
            fitted_image = rbf_nd(rbf_interpolator, dims)
            return res.x, fitted_image

        raise ValueError("peak finding method not recognised")

    @staticmethod
    def _plot_drifts(drifts):
        """Best-effort diagnostic scatter plots of the drifts (x-y and x-z); any error is printed, not raised."""
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6,3))
            axes[0].scatter(drifts[:,0], drifts[:,1], s=50)
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            axes[1].scatter(drifts[:,0], drifts[:,2], s=50)
            axes[1].set_xlabel('x')
            axes[1].set_ylabel('z')
            for ax in axes:
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')
            fig.tight_layout()
        except Exception as e:
            print(e)

    def shift_images(self, psf_stack, shifts):
        """
        Return a copy of ``psf_stack`` with stack i shifted by ``shifts[i]`` (x, y, z, in pixels).

        Uses the Fourier shift theorem (sub-pixel, periodic boundaries) and keeps the magnitude of the inverse FFT.
        """
        kx = np.fft.fftfreq(psf_stack.shape[0])
        ky = np.fft.fftfreq(psf_stack.shape[1])
        kz = np.fft.fftfreq(psf_stack.shape[2])
        kx, ky, kz = np.meshgrid(kx, ky, kz, indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in range(psf_stack.shape[3]):
            ft_image = np.fft.fftn(psf_stack[:,:,:,i])
            shift = shifts[i]
            phase_ramp = np.exp(-2j*np.pi*(kx*shift[0] + ky*shift[1] + kz*shift[2]))
            shifted_images[:,:,:,i] = np.abs(np.fft.ifftn(ft_image*phase_ramp))

        return shifted_images
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """
        Cross-correlate pairs of Fourier-transformed frames and build the RCC linear system.

        Which pairs (i, j), i < j, are correlated depends on ``method`` and ``corr_window``:

        * DCC: only pairs (0, j).
        * Otherwise, with ``corr_window > 0``: pairs with j - i <= corr_window.
        * Otherwise: all pairs.

        Parameters
        ----------
        ft_images : array
            Fourier-transformed frames indexed by time step along axis 0.

        Returns
        -------
        shifts : (n_pairs, 3) array
            Measured shift for each correlated pair; rows containing NaN are dropped.
        coefs : (n_pairs, n_steps - 1) array
            Row k has ones over the frame-to-frame steps i..j-1 spanned by pair k.

        Notes
        -----
        * Requires ``self._start_time`` (set in ``_execute``).
        * Requires ``self._pool`` when ``multiprocessing`` is enabled.
        * Sets the ``_cc_image`` trait as a side effect.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [shifts.shape[0], ft_images.shape[1], ft_images.shape[2], (ft_images.shape[3] - 1) * 2]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]), key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # builtin float: the np.float alias was removed in NumPy 1.24
            cc_file_args = (self.debug_cor_file, float, tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0], dtype=cc_file_args[1], mode="w+", shape=cc_file_args[2])
            # list(): cc_args is indexed by counter below, and zip() returns a non-subscriptable iterator on Python 3
            cc_args = list(zip(range(shifts.shape[0]), (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift, None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                            time.time() - self._start_time, counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache, autocor_shift_cache,
                       len(ft_1_cache) * ((self.cache_fft, ft_images.dtype, ft_images.shape), ), cc_args)

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                        time.time() - self._start_time, i + 1, coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop pairs whose correlation produced NaN shifts
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (np.linalg.matrix_rank(coefs) == n_steps - 1), \
            "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs  # shifts.shape[0] is n_steps - 1

    def rcc(self, shift_max, t_shift, shifts, coefs, ):
        """
        Should probably rename function.
        Takes cross correlation results and calculates shifts.

        Steps:

        1. Solve ``coefs @ drifts = shifts`` by least squares (pseudo-inverse).
        2. For ``method == "RCC"`` only: zero out the rows whose residual exceeds ``shift_max``, largest first,
           as long as ``coefs`` stays full rank, then solve again. ``coefs`` is modified in place.

        Returns ``(t_shift, drifts)``, where ``drifts`` holds frame-to-frame drifts with a zero row prepended for
        the first time point.
        """
        print("{:.2f} s. About to start solving shifts array.".format(time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() - self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print("Could not remove all residuals over shift_max threshold.")
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # derived versions of RCC need to override this method
        # 'execute' of this RCC base class is not thoroughly tested as its use is probably quite limited.
        #
        # NOTE(review): as written this cannot run. `cache_fft` (a filename) is passed where
        # calc_corr_drift_from_ft_images expects an array. That method also returns (shifts, coefs), while
        # rcc() expects (t_shift, shifts, coefs). Confirm the intended t_shift source before relying on this.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # release worker processes even if the drift calculation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class AveragePSF(ModuleBase):
    """
        Input stacks of PSF and return the (normalized) average PSF.
        Additional filter based on max error/residual between image and averaged image.

        Processing steps:

        1. Each stack (axis 3 of the input) is normalised for statistics.
        2. Stacks whose largest absolute deviation from the mean stack is not below ``residual_threshold``
           are discarded.
        3. The variance across the remaining normalised stacks is written to ``output_var_image``.
        4. The remaining stacks (normalised, or raw if ``normalize_intensity`` is False) are averaged and
           optionally Gaussian-smoothed (``gaussian_filter`` sigmas per axis, all zero = no smoothing).
        5. The result is rescaled to [0, 1] and written to ``output_images``.
    """
    inputName = Input('psf_aligned')
    normalize_intensity = Bool(False)
    output_var_image = Output('psf_var')
    gaussian_filter = List(Float, [0, 0, 0], 3, 3)
    residual_threshold = Float(0.1)
    output_images = Output('psf_combined')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_raw = ims.data[:,:,:,:]

        # always normalize first, since needed for statistics
        psf_raw_norm = psf_raw.copy()
        psf_raw_norm /= psf_raw_norm.max(axis=(0,1,2), keepdims=True)
        psf_raw_norm /= psf_raw_norm.sum(axis=(0,1), keepdims=True)
        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        # largest absolute deviation of each stack from the mean of all stacks
        residual_max = np.abs(psf_raw_norm - psf_raw_norm.mean(axis=3, keepdims=True)).max(axis=(0,1,2))
        print(residual_max)
        mask = residual_max < self.residual_threshold
        print("images ignore: {}".format(np.argwhere(~mask)[:,0]))
        print(mask)

        if not np.any(mask):
            # otherwise the reductions below fail with an opaque "zero-size array" error
            raise ValueError("All {} PSF stacks exceed residual_threshold={}; nothing left to average.".format(
                mask.size, self.residual_threshold))

        psf_raw_norm = psf_raw_norm[:,:,:,mask]
        print(psf_raw_norm.shape)

        psf_raw_norm -= psf_raw_norm.min()
        psf_raw_norm /= psf_raw_norm.max()

        psf_var = psf_raw_norm.var(axis=3)
        namespace[self.output_var_image] = ImageStack(psf_var, mdh=ims.mdh)

        # if requested not to normalize, revert back to original data
        if not self.normalize_intensity:
            psf_raw_norm = psf_raw[:,:,:,mask]  # boolean indexing already returns a copy

        psf_combined = psf_raw_norm.mean(axis=3)
        psf_combined -= psf_combined.min()
        psf_combined /= psf_combined.max()

        if np.any(np.asarray(self.gaussian_filter) != 0):
            psf_processed = ndimage.gaussian_filter(psf_combined, self.gaussian_filter)
        else:
            psf_processed = psf_combined

        psf_processed -= psf_processed.min()
        psf_processed /= psf_processed.max()

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["PSFExtraction.GaussianFilter"] = self.gaussian_filter
            new_mdh["PSFExtraction.NormalizeIntensity"] = self.normalize_intensity
        except Exception as e:
            # best effort: the output is still produced without (complete) metadata
            print(e)
        namespace[self.output_images] = ImageStack(psf_processed, mdh=new_mdh)

        self._plot_profiles(psf_raw_norm, psf_combined, psf_processed, residual_max)

    def _plot_profiles(self, psf_raw_norm, psf_combined, psf_processed, residual_max):
        """
        Diagnostic plots.

        * Central X/Y/Z line profiles of every kept stack (top row).
        * The combined (red) and processed (dashed black) PSF profiles (bottom row).
        * A histogram of per-stack residuals, with the threshold marked.
        """
        def centre_line(arr, axis):
            # full extent along `axis`, centre index along the other two spatial axes (extra axes kept)
            idx = tuple(slice(None) if d == axis else arr.shape[d]//2 for d in range(3))
            return arr[idx]

        fig, axes = pyplot.subplots(2, 3, figsize=(9,6))
        for col, title in enumerate(('X', 'Y', 'Z')):
            axes[0,col].set_title(title)
            axes[0,col].plot(centre_line(psf_raw_norm, col))
            axes[1,col].plot(centre_line(psf_combined, col), lw=1, color='red')
            axes[1,col].plot(centre_line(psf_processed, col), lw=1, ls='--', color='black')
        fig.tight_layout()

        fig, ax = pyplot.subplots(1, 1, figsize=(4,3))
        ax.hist(residual_max, bins=20)
        ax.axvline(self.residual_threshold, color='red', ls='--')
class LayerWrapper(HasTraits):
    """
    Traits wrapper binding a rendering engine (selected from ``ENGINES`` by ``method``) to a named
    data source, and exposing the display settings (LUT, colour limits, alpha, visibility) as traits.

    Listeners connected to ``on_update`` are signalled after the layer updates or its visibility changes.
    """
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self, pipeline, method='points', ds_name='', cmap='gist_rainbow', clim=None, alpha=1.0,
                 visible=True, method_args=None):
        """
        Parameters
        ----------
        pipeline : object exposing ``layer_datasources`` (name -> datasource mapping) and an ``onRebuild`` signal
        method : str, key into ``ENGINES`` selecting the rendering engine
        ds_name : str, datasource name ('' renders ``pipeline`` itself, 'name.channel' selects a colour channel)
        cmap, clim, alpha, visible : initial display settings (``clim`` defaults to [0, 1])
        method_args : dict, optional, initial parameters applied to the engine via ``engine.set``
        """
        self._pipeline = pipeline
        self.engine = None
        self.cmap = cmap
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha
        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name
        # copy so that the caller's dict is never shared with (or mutated through) this layer
        self._eng_params = dict(method_args or {})
        self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        return self.engine.bbox

    @property
    def colour_map(self):
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """All selectable datasource names, including 'name.channel' entries for colour-filtered sources."""
        names = []
        for key, source in self._namespace.items():
            names.append(key)
            if isinstance(source, tabular.ColourFilter):
                names.extend('.'.join([key, chan]) for chan in source.getColourChans())
        return names

    @property
    def datasource(self):
        """The datasource currently selected by ``dsname``, or None if it does not (yet) exist."""
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            base = self._namespace.get(ds, None)
            if base is None:
                return None
            return base.get_channel_ds(channel)

        return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """(Re)create the rendering engine for the current ``method``, carrying over engine parameters."""
        if self.engine is not None:
            self._eng_params = self.engine.get('point_size', 'vertexColour')

        self.engine = ENGINES[self.method](self._context)
        self.engine.set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset the colour limits to the full range of the current colour data."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Push the current datasource and display settings to the engine and notify listeners."""
        if not (self.engine is None or self.datasource is None):
            self.engine.update_from_datasource(self.datasource, getattr(cm, self.cmap), self.clim, self.alpha)

        self.on_update.send(self)

    def render(self, gl_canvas):
        if self.visible:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """Values of the engine's colour column, or [0, 1] if unavailable."""
        ds = self.datasource
        if ds is None:
            return np.array([0, 1])
        try:
            cdata = ds[self.engine.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])
        return cdata

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(values=self.data_source_names)), ]),
                     Item('method'),
                     Group([Item('engine', style='custom', show_label=False,
                                 editor=InstanceEditor(view=self.engine.view(self.datasource.keys()))), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'), Item('alpha'), Item('visible')],
                           orientation='horizontal', layout='flow')], )

    def default_traits_view(self):
        return self.default_view
class TetrahedraRenderLayer(VertexRenderLayer):
    """
    Triangulates a 3D point set with ``gen3DTriangs`` and draws the result as GL_TRIANGLES, using either a
    Gouraud-shaded or a wireframe shader program.
    """
    z_rescale = Float(1.0)
    size_cutoff = Float(1000.)
    internal_cull = Bool(True)
    wireframe = Bool(False)
    DRAW_MODE = GL_TRIANGLES

    def __init__(self, x=None, y=None, z=None, colors=None, color_map=None, size_cutoff=None,
                 internal_cull=None, z_rescale=None, alpha=None, is_wire_frame=False):
        """
        Parameters
        ----------
        x, y, z : point coordinates; when x is given the points are triangulated immediately
            (z defaults to zeros, as in ``update_from_datasource``)
        colors : 'z' to colour by vertex z, otherwise triangles are coloured by ``1. / a`` (second
            output of gen3DTriangs)
        color_map : callable LUT mapping values in [0, 1] to RGBA
        size_cutoff, internal_cull, z_rescale : override the trait defaults when not None
            (a z_rescale of 0 is ignored since z is divided by it)
        alpha : opacity applied to the vertex colours (defaults to 1.0)
        is_wire_frame : use the wireframe rather than the Gouraud shader program
        """
        super(TetrahedraRenderLayer, self).__init__(colors=colors, color_map=color_map, alpha=alpha)

        # explicit None checks so that e.g. internal_cull=False is honoured
        if size_cutoff is not None:
            self.size_cutoff = size_cutoff
        if internal_cull is not None:
            self.internal_cull = internal_cull
        if z_rescale:
            self.z_rescale = z_rescale

        if x is not None:
            x = np.asarray(x)
            z = np.zeros(x.shape) if z is None else np.asarray(z)
            p, a, _ = gen3DTriangs(x, y, z / self.z_rescale, self.size_cutoff, internalCull=self.internal_cull)
            if isinstance(colors, str) and colors == 'z':
                colors = p[:, 2]
            else:
                colors = 1. / a
            color_limit = [colors.min(), colors.max()]
            # use the triangulated vertices - `colors` is per triangulated vertex, not per input point
            self.update_data(p[:, 0], p[:, 1], p[:, 2], colors, cmap=color_map, clim=color_limit,
                             alpha=1.0 if alpha is None else alpha)

        if is_wire_frame:
            self.set_shader_program(WireFrameShaderProgram)
        else:
            self.set_shader_program(GouraudShaderProgram)

    def update_from_datasource(self, ds, cmap=None, clim=None, alpha=1.0):
        """Triangulate the x/y/(z) columns of `ds` and upload the resulting triangles."""
        x, y = ds[self.x_key], ds[self.y_key]
        if self.z_key is not None:
            z = ds[self.z_key]
        else:
            z = 0 * x

        p, a, _ = gen3DTriangs(x, y, z / self.z_rescale, self.size_cutoff, internalCull=self.internal_cull)

        # TODO - colour by self.vertexColour once interpolated triangles are set up; for now always 1/a
        c = 1. / a

        self.update_data(p[:, 0], p[:, 1], p[:, 2], c, cmap=cmap, clim=clim, alpha=alpha)

    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0):
        """
        Build vertex, normal and RGBA colour arrays and hand them to ``set_values``.

        Colours are mapped through `cmap` over `clim` when `cmap`, `clim` and `colors` are all given;
        otherwise every vertex is white (or there are no colours if there are no vertices).
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            # (n, 3) array of [x, y, z] rows
            vertices = np.column_stack((np.ravel(x), np.ravel(y), np.ravel(z)))
            normals = -0.69 * np.ones(vertices.shape)
        else:
            vertices = None
            normals = None

        if cmap is not None and clim is not None and colors is not None:
            span = clim[1] - clim[0]
            if span == 0:
                span = 1.0  # avoid dividing by zero when all colour values are equal
            cs = cmap((colors - clim[0]) / span)
            cs[:, 3] = alpha
        elif vertices is not None:
            cs = np.ones((vertices.shape[0], 4), 'f')
        else:
            cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def render(self, gl_canvas):
        """
        Draw the stored vertices as triangles with the layer's shader program.

        Parameters
        ----------
        gl_canvas
            Unused - rendering depends only on the layer's own state.
        """
        with self.shader_program:
            vertices = self.get_vertices()
            if vertices is None:
                return  # nothing to draw yet
            n_vertices = vertices.shape[0]

            glVertexPointerf(vertices)
            glNormalPointerf(self.get_normals())
            glColorPointerf(self.get_colors())

            glPushMatrix()
            glDrawArrays(self.DRAW_MODE, 0, n_vertices)
            glPopMatrix()