class PSFSettings(HasTraits):
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])
    four_pi = Bool(False)

    def default_traits_view(self):
        from traitsui.api import View, Item
        #from PYME.ui.custom_traits_editors import CBEditor

        return View(Item(name='wavelength_nm'),
                    Item(name='NA'),
                    Item(name='vectorial'),
                    Item(name='four_pi', label='4Pi'),
                    Item(name='zernike_modes'),
                    Item(name='zernike_modes_lower', visible_when='four_pi==True'),
                    Item(name='phases', visible_when='four_pi==True', label='phases/pi'),
                    resizable=True,
                    buttons=['OK'])
class HistByID(ModuleBase):
    """Plot a histogram of a column, grouped by ID"""
    inputName = Input('measurements')
    IDkey = CStr('objectID')
    histkey = CStr('qIndex')
    outputName = Output('outGraph')
    nbins = Int(50)
    minval = Float(float('nan'))
    maxval = Float(float('nan'))

    def execute(self, namespace):
        import math
        meas = namespace[self.inputName]
        ids = meas[self.IDkey]

        uid, valsu = uniqueByID(ids, meas[self.histkey])
        if math.isnan(self.minval):
            minv = valsu.min()
        else:
            minv = self.minval
        if math.isnan(self.maxval):
            maxv = valsu.max()
        else:
            maxv = self.maxval

        import matplotlib.pyplot as plt
        plt.figure()
        plt.hist(valsu, self.nbins, range=(minv, maxv))
        plt.xlabel(self.histkey)

    @property
    def _key_choices(self):
        # try and find the available column names
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('histkey', editor=CBEditor(choices=self._key_choices)),
                    Item('nbins'),
                    Item('minval'),
                    Item('maxval'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
class Scale(Filter):
    """Scale image intensities by a constant"""
    scale = Float(1)

    def applyFilter(self, data, chanNum, i, image0):
        return self.scale * data
class NormalizeMean(Filter):
    """Normalize an image so that the mean is 1"""
    offset = Float(0)

    def applyFilter(self, data, chanNum, i, image0):
        data = data - self.offset
        return data / float(data.mean())
class QindexScale(ModuleBase):
    inputName = Input('qindex')
    outputName = Output('qindex-calibrated')
    qIndexkey = CStr('qIndex')
    qindexValue = Float(1.0)
    NEquivalent = Float(1.0)

    def execute(self, namespace):
        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)
        qkey = self.qIndexkey

        scaled = inp[qkey]
        qigood = inp[qkey] > 0
        scaled[qigood] = inp[qkey][qigood] * self.NEquivalent / self.qindexValue

        self.newKey = '%sCal' % qkey
        mapped.addColumn(self.newKey, scaled)
        namespace[self.outputName] = mapped

    @property
    def _key_choices(self):
        # try and find the available column names
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('qIndexkey', editor=CBEditor(choices=self._key_choices)),
                    Item('qindexValue'),
                    Item('NEquivalent'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
class RuleTile(HasTraits):
    task_timeout = Float(60 * 10)
    rule_timeout = Float(60 * 10)

    def get_params(self):
        editable = self.class_editable_traits()
        return editable

    @property
    def default_view(self):
        if wx.GetApp() is None:
            return None

        from traitsui.api import View, Item

        return View([Item(tn) for tn in self.get_params()], buttons=['OK'])

    def default_traits_view(self):
        """This is the traits stock method to specify the default view"""
        return self.default_view
class WormlikeSource(PointSource):
    kbp = Float(200)
    steplength = Float(1.0)
    lengthPerKbp = Float(10.0)
    persistLength = Float(150.0)
    #name = Str('Wormlike Chain')

    def getPoints(self):
        from PYME.simulation import wormlike2
        wc = wormlike2.wormlikeChain(self.kbp, self.steplength, self.lengthPerKbp,
                                     self.persistLength)
        return wc.xp, wc.yp, wc.zp

    def genMetaData(self, mdh):
        mdh['GeneratedPoints.Source.Type'] = 'Wormlike'
        mdh['GeneratedPoints.Source.Kbp'] = self.kbp
        mdh['GeneratedPoints.Source.StepLength'] = self.steplength
        mdh['GeneratedPoints.Source.LengthPerKbp'] = self.lengthPerKbp
        mdh['GeneratedPoints.Source.PersistLength'] = self.persistLength
class LabelByRegionProperty(Filter):
    """Assigns a region property to each contiguous region in the input mask.
    Optionally throws away all regions for which the property is outside a given range.
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def applyFilter(self, data, chanNum, frNum, im):
        mask = data > 0.5
        labs, nlabs = ndimage.label(mask)
        rp = skimage.measure.regionprops(labs, None, cache=True)
        m2 = np.zeros_like(mask, dtype='float')
        objs = ndimage.find_objects(labs)
        for region in rp:
            oslices = objs[region.label - 1]
            r = labs[oslices] == region.label
            if self.regionProperty == 'area':
                propValue = region.area
            elif self.regionProperty == 'aspectratio':
                propValue = region.major_axis_length / region.minor_axis_length
            elif self.regionProperty == 'circularity':
                propValue = 4 * math.pi * region.area / (region.perimeter * region.perimeter)
            if self.filterByProperty:
                if (propValue >= self.propertyMin) and (propValue <= self.propertyMax):
                    m2[oslices] += r * propValue
            else:
                m2[oslices] += r * propValue

        return m2

    def completeMetadata(self, im):
        im.mdh['Labelling.Property'] = self.regionProperty
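# ---------------------------------------------------------------------------
# Editor's sketch: the 'circularity' used above is the isoperimetric ratio
# 4*pi*A/P**2 (1.0 for a perfect disc, smaller for elongated/ragged shapes).
# A self-contained toy version of the same per-region computation:
import math

import numpy as np
import skimage.measure
from scipy import ndimage

mask = np.zeros((32, 32), dtype=bool)
mask[8:24, 8:24] = True  # one filled square region

labs, nlabs = ndimage.label(mask)
for region in skimage.measure.regionprops(labs):
    circularity = 4 * math.pi * region.area / region.perimeter ** 2
    aspectratio = region.major_axis_length / region.minor_axis_length
    print(region.label, circularity, aspectratio)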
class FlexiThreshold(Filter):
    """Choose a threshold using a range of available thresholding methods.
    Currently we can choose from: simple, fractional, otsu, isodata
    """
    method = Enum('simple', 'fractional', 'otsu', 'isodata',
                  'li', 'yen')  # newer skimage has minimum, mean and triangle as well
    parameter = Float(0.5)
    clipAt = Float(2e6)  # used to be 10 - increased to a large value for newer PYME renderings

    def fractionalThreshold(self, data):
        N, bins = np.histogram(data, bins=5000)
        # bin left edges (used as approximate bin centres)
        bin_mids = bins[:-1]
        cN = np.cumsum(N * bin_mids)
        i = np.argmin(abs(cN - cN[-1] * (1 - self.parameter)))
        threshold = bins[i]
        return threshold

    def applyFilter(self, data, chanNum, frNum, im):
        if self.method == 'fractional':
            threshold = self.fractionalThreshold(np.clip(data, None, self.clipAt))
        elif self.method == 'simple':
            threshold = self.parameter
        else:
            method = getattr(skf, 'threshold_%s' % self.method)
            threshold = method(np.clip(data, None, self.clipAt))

        mask = data > threshold
        return mask

    def completeMetadata(self, im):
        im.mdh['Processing.ThresholdParameter'] = self.parameter
        im.mdh['Processing.ThresholdMethod'] = self.method
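# ---------------------------------------------------------------------------
# Editor's sketch: the 'fractional' method picks the threshold such that
# roughly `parameter` of the total intensity mass (histogram counts weighted
# by intensity) lies above it. The same idea on toy data:
import numpy as np

rng = np.random.default_rng(0)
data = rng.exponential(scale=10.0, size=100000)
parameter = 0.5  # fraction of total intensity mass to keep above threshold

N, bins = np.histogram(data, bins=5000)
bin_mids = bins[:-1]  # left edges, as in fractionalThreshold above
cN = np.cumsum(N * bin_mids)  # cumulative intensity mass
i = np.argmin(np.abs(cN - cN[-1] * (1 - parameter)))
threshold = bins[i]

print(threshold, data[data > threshold].sum() / data.sum())  # second value ~= parameter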
class InterpolateDrift(ModuleBase):
    """
    Creates a spline interpolator from drift data. (``scipy.interpolate.UnivariateSpline``)

    Inputs
    ------
    input_drift_raw : Tuple of arrays
        Drift measured from localization or image dataset.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator. Returns drift when called with frame number / time.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    degree_of_spline : int
        Degree of the smoothing spline.
    smoothing_factor : float
        Smoothing factor.
    """
    # input_dummy = Input('input')  # breaks GUI without this???
    # load_path = File()
    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(-1)  # 0 for no smoothing; set to negative for the UnivariateSpline default
    input_drift_raw = Input('drift_raw')
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        # data = np.load(self.load_path)
        # tIndex = data['tIndex']
        # drift = data['drift']
        # namespace[self.output_drift_raw] = (tIndex, drift)

        tIndex, drift = namespace[self.input_drift_raw]

        spl = interpolate_drift(tIndex, drift, self.degree_of_spline, self.smoothing_factor)

        namespace[self.output_drift_interpolator] = spl

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, tIndex, drift, spl))
        namespace[self.output_drift_plot].plot()
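# ---------------------------------------------------------------------------
# Editor's sketch: a minimal example of the kind of spline fit wrapped by
# interpolate_drift, using scipy.interpolate.UnivariateSpline directly on one
# drift axis (toy data; names illustrative):
import numpy as np
from scipy.interpolate import UnivariateSpline

t = np.arange(0, 1000, 50)  # frame numbers with drift measurements
drift_x = 0.01 * t + np.random.default_rng(1).normal(0, 0.5, t.shape)

# k: spline degree (1 linear, 3 cubic); s=None lets scipy choose the default
# smoothing, analogous to a negative smoothing_factor above
spl = UnivariateSpline(t, drift_x, k=3, s=None)

frames = np.arange(1000)
dx_per_frame = spl(frames)  # drift evaluated at every frame
print(dx_per_frame[:5])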
class IsClose(ArithmaticFilter):
    """
    Wrapper for numpy.isclose

    Parameters
    ----------
    abs_tolerance : Float
        absolute tolerance
    rel_tolerance : Float
        relative tolerance

    Notes
    -----
    from numpy docs, tolerances are combined as:
    absolute(a - b) <= (atol + rtol * absolute(b))
    """
    abs_tolerance = Float(1e-8)
    rel_tolerance = Float(1e-5)

    def applyFilter(self, data0, data1, chanNum, i, image0):
        return np.isclose(data0, data1, atol=self.abs_tolerance, rtol=self.rel_tolerance)
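# ---------------------------------------------------------------------------
# Editor's note: the combined test |a - b| <= atol + rtol*|b| is asymmetric in
# its arguments, since the relative term scales with the second operand:
import numpy as np

a, b = 100.0, 101.0
rtol = 0.00995
print(np.isclose(a, b, rtol=rtol, atol=0.0))  # True:  1.0 <= 0.00995 * 101
print(np.isclose(b, a, rtol=rtol, atol=0.0))  # False: 1.0 >  0.00995 * 100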
class TimedSpecies(ModuleBase):
    inputName = Input('filtered')
    outputName = Output('timedSpecies')
    Species_1_Name = CStr('Species1')
    Species_1_Start = Float(0)
    Species_1_Stop = Float(1e6)

    Species_2_Name = CStr('')
    Species_2_Start = Float(0)
    Species_2_Stop = Float(0)

    Species_3_Name = CStr('')
    Species_3_Start = Float(0)
    Species_3_Stop = Float(0)

    def execute(self, namespace):
        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)
        timedSpecies = self.populateTimedSpecies()

        mapped.addColumn('ColourNorm', np.ones_like(mapped['t'], 'float'))
        for species in timedSpecies:
            mapped.addColumn('p_%s' % species['name'],
                             (mapped['t'] >= species['t_start']) *
                             (mapped['t'] < species['t_end']))

        if 'mdh' in dir(inp):
            mapped.mdh = inp.mdh
            mapped.mdh['TimedSpecies'] = timedSpecies

        namespace[self.outputName] = mapped

    def populateTimedSpecies(self):
        ts = []
        if self.Species_1_Name:
            ts.append({'name': self.Species_1_Name,
                       't_start': self.Species_1_Start,
                       't_end': self.Species_1_Stop})

        if self.Species_2_Name:
            ts.append({'name': self.Species_2_Name,
                       't_start': self.Species_2_Start,
                       't_end': self.Species_2_Stop})

        if self.Species_3_Name:
            ts.append({'name': self.Species_3_Name,
                       't_start': self.Species_3_Start,
                       't_end': self.Species_3_Stop})

        return ts
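# ---------------------------------------------------------------------------
# Editor's sketch: each p_<name> column above is a boolean membership test on
# the time coordinate over a half-open interval [t_start, t_end), stored as
# 0/1. A toy numpy version (names illustrative):
import numpy as np

t = np.arange(0, 10000, 500)  # event times / frame numbers
species = {'name': 'Species1', 't_start': 0, 't_end': 5000}

p = (t >= species['t_start']) * (t < species['t_end'])
print(p.astype(int))  # 1 where the event falls inside the species' time window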
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its
    distances to all other points
    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  # number of bins (starting at 0)
    threeD = Bool(True)

    normaliseRelativeDensity = Bool(False)  # divide by the sum of all radial bins. If not performed, the first principal component will likely be average density

    def execute(self, namespace):
        from PYME.Analysis.points import DistHist
        points = namespace[self.inputLocalisations]

        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([DistHist.distanceHistogram3D(x[i], y[i], z[i], x, y, z,
                                                       self.numBins, self.binWidth)
                          for i in range(len(x))])
        else:
            x, y = points['x'], points['y']
            f = np.array([DistHist.distanceHistogram(x[i], y[i], x, y,
                                                     self.numBins, self.binWidth)
                          for i in range(len(x))])

        namespace[self.outputName] = self._process_features(points, f)
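# ---------------------------------------------------------------------------
# Editor's sketch: per-point feature vectors like the above are essentially
# histograms of a point's distances to all other points. DistHist uses
# compiled helpers; this pure-numpy version is equivalent in spirit only:
import numpy as np

rng = np.random.default_rng(2)
pts = rng.uniform(0, 1000, size=(200, 3))  # toy 3D point cloud (nm)
num_bins, bin_width = 20, 100.0

def point_feature(i, pts):
    # histogram of distances from point i to every point (bin 0 includes self)
    d = np.linalg.norm(pts - pts[i], axis=1)
    h, _ = np.histogram(d, bins=num_bins, range=(0, num_bins * bin_width))
    return h

features = np.array([point_feature(i, pts) for i in range(len(pts))])
print(features.shape)  # (200, 20)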
class InterpolatePSF(ModuleBase):
    """
    Interpolate PSF with RBF. Very stupid. Very slow. Performed on local pixels and
    combined by tiling. Only uses the first colour channel.
    """
    inputName = Input('input')
    rbf_radius = Float(250.0)
    target_voxelsize = List(Float, [100., 100., 100.])
    output_images = Output('psf_interpolated')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        data = ims.data[:, :, :, 0]

        dims_original = list()
        voxelsize = [ims.mdh.voxelsize.x, ims.mdh.voxelsize.y, ims.mdh.voxelsize.z]
        for dim, dim_len in enumerate(data.shape):
            d = np.linspace(0, dim_len - 1, dim_len) * voxelsize[dim] * 1E3
            d -= d.mean()
            dims_original.append(d)
        X, Y, Z = np.meshgrid(*dims_original, indexing='ij')

        dims_interpolated = list()
        for dim, dim_len in enumerate(data.shape):
            tar_len = int(np.ceil((voxelsize[dim] * 1E3 * dim_len) / self.target_voxelsize[dim]))
            d = np.arange(tar_len) * self.target_voxelsize[dim]
            d -= d.mean()
            dims_interpolated.append(d)

        X_interp, Y_interp, Z_interp = np.meshgrid(*dims_interpolated, indexing='ij')

        pts_interp = list(zip(X_interp.flatten(), Y_interp.flatten(), Z_interp.flatten()))
        results = np.zeros(X_interp.size)
        for i, pt in enumerate(pts_interp):
            # pass the configured radius through rather than relying on the default
            results[i] = self.InterpolateAt(pt, X, Y, Z, data, self.rbf_radius)
            if i % 100 == 0:
                print("{} out of {} completed.".format(i, results.shape[0]))

        results = results.reshape(len(dims_interpolated[0]), len(dims_interpolated[1]),
                                  len(dims_interpolated[2]))

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["voxelsize.x"] = self.target_voxelsize[0] * 1E-3
            new_mdh["voxelsize.y"] = self.target_voxelsize[1] * 1E-3
            new_mdh["voxelsize.z"] = self.target_voxelsize[2] * 1E-3
            new_mdh['Interpolation.Method'] = 'RBF'
            new_mdh['Interpolation.RbfRadius'] = self.rbf_radius
        except Exception as e:
            print(e)

        namespace[self.output_images] = ImageStack(data=results, mdh=new_mdh)

    def InterpolateAt(self, pt, X, Y, Z, data, radius=250.):
        X_subset, Y_subset, Z_subset, data_subset = self.GetPointsInNeighbourhood(pt, X, Y, Z, data, radius)
        rbf = interpolate.Rbf(X_subset, Y_subset, Z_subset, data_subset,
                              function="cubic", smooth=1E3)  #, norm=euclidean_norm_numpy)
        return rbf(*pt)

    def GetPointsInNeighbourhood(self, center, X, Y, Z, data, radius=250.):
        distance = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2 + (Z - center[2]) ** 2)
        # use the radius argument (was previously hard-coded to 250.)
        mask = distance < radius
        return X[mask], Y[mask], Z[mask], data[mask]
class Pow(Filter):
    """Raise an image to a given power (can be fractional for sqrt)"""
    power = Float(2)

    def applyFilter(self, data, chanNum, i, image0):
        return np.power(data, self.power)
class EnsembleFitROIs(ModuleBase):  # Note that this should probably be moved somewhere else
    inputName = Input('ROIs')

    fit_type = CStr('LorentzianConvolvedSolidSphere_ensemblePSF')
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        inp = namespace[self.inputName]

        # generate RegionHandler from tables
        handler = RegionHandler()
        handler._load_from_list(inp)

        fit_class = region_fitters.fitters[self.fit_type]
        fitter = fit_class(handler)

        if self.hold_ensemble_parameter_constant:
            fitter.fit_profiles(self.ensemble_parameter_guess)
        else:
            fitter.ensemble_fit(self.ensemble_parameter_guess)

        res = tabular.RecArraySource(fitter.results)

        # propagate metadata, if present
        res.mdh = MetaDataHandler.NestedClassMDHandler(getattr(inp, 'mdh', None))
        res.mdh['EnsembleFitROIs.FitType'] = self.fit_type
        res.mdh['EnsembleFitROIs.EnsembleParameterGuess'] = self.ensemble_parameter_guess
        res.mdh['EnsembleFitROIs.HoldEnsembleParamConstant'] = self.hold_ensemble_parameter_constant

        namespace[self.outputName] = res

    @property
    def _fitter_choices(self):
        return list(region_fitters.ensemble_fitters.keys())  # FIXME???

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'))

    @property
    def dsview_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    buttons=['OK'])
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering
    engines from PYME.LMVis.layers.triangle_mesh.
    """
    # Additional properties (the rest are inherited from TriangleRenderLayer)
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, context, **kwargs)

        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Feeds the triangle points/normals to update_data.

        Parameters
        ----------
        ds : Octree
            (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)

            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            nodes = nodes[node_density >= self.density]

            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2 ** nodes['depth']) ** 3)
        else:
            # Plot leaf nodes
            nodes = nodes[(np.sum(nodes['children'], axis=1) == 0) & (nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2.0 ** nodes['depth'])) ** 3

        if len(nodes) > 0:
            c = nodes['centre']  # node centres
            shifts = (box_sizes[:, None] * OCT_SHIFT[None, :]) * 0.5
            v = (c[:, None, :] + shifts)

            #
            #     z
            #     ^
            #     |
            #    v4 ----------v6
            #    /|           /|
            #   / |          / |
            #  v5----------v7  |
            #  |  |    c    |  |
            #  | v0---------|-v2
            #  | /          | /
            #  v1-----------v3---> y
            #  /
            # x
            #
            # Counterclockwise triangles (when viewed straight-on) formed along the
            # faces of the cube:
            #
            # v0 v2 v1
            # v0 v1 v5
            # v0 v5 v4
            # v0 v6 v2
            # v0 v4 v6
            # v1 v2 v3
            # v1 v3 v7
            # v1 v7 v5
            # v2 v6 v7
            # v2 v7 v3
            # v4 v5 v6
            # v5 v7 v6

            t0 = np.vstack(v[:, [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 4, 5], :])
            t1 = np.vstack(v[:, [2, 1, 5, 6, 4, 2, 3, 7, 6, 7, 5, 7], :])
            t2 = np.vstack(v[:, [1, 5, 4, 2, 6, 3, 7, 5, 7, 3, 6, 6], :])

            x, y, z = np.hstack([t0, t1, t2]).reshape(-1, 3).T  # positions

            # Now we create the normals as the cross product
            tn = np.cross((t2 - t1), (t0 - t1))

            # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
            xn, yn, zn = np.repeat(tn.T, 3, axis=1)  # normals

            # Colour is a fixed constant for the octree
            c = np.ones(len(x))
            clim = [0, 1]

            alpha = self.alpha * alpha / alpha.max()
            alpha = (alpha[None, :] * np.ones(12)[:, None])
            alpha = np.repeat(alpha.ravel(), 3)
            print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

            cmap = getattr(cm, self.cmap)

            # Do we have coordinates? Concatenate into vertices.
            if x is not None and y is not None and z is not None:
                vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
                self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

                if not xn is None:
                    self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
                else:
                    self._normals = -0.69 * np.ones(self._vertices.shape)

                self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
            else:
                self._bbox = None

            if clim is not None and c is not None and cmap is not None:
                cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
                cs = cmap(cs_)

                if self.method in ['flat', 'tessel']:
                    alpha = cs_ * alpha

                cs[:, 3] = alpha

                if self.method == 'tessel':
                    cs = np.power(cs, 0.333)

                self._colors = cs.ravel().reshape(len(c), 4)
            else:
                # cs = None
                if not self._vertices is None:
                    self._colors = np.ones((self._vertices.shape[0], 4), 'f')
                print('Colors: {}'.format(self._colors))

            self._alpha = alpha
            self._color_map = cmap
            self._color_limit = clim
        else:
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([#Group([
                     Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
                     Item('method'),
                     Item('depth'),
                     Item('min_points'),
                     #Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     #Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha')])
                     ], )
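# ---------------------------------------------------------------------------
# Editor's sketch: OCT_SHIFT is a module-level constant not shown in this
# excerpt. Given how it is used above (centre + box_size * OCT_SHIFT * 0.5),
# it is presumably the (8, 3) table of unit-cube corner offsets in {-1, 1},
# ordered so that v0..v7 match the ASCII diagram (bit 0 -> x, bit 1 -> y,
# bit 2 -> z). One way such a table could be built (an assumption, not the
# library's actual definition):
import numpy as np

OCT_SHIFT_SKETCH = np.array([[x, y, z]
                             for z in (-1, 1)
                             for y in (-1, 1)
                             for x in (-1, 1)], dtype='f4')
# rows: v0=(-1,-1,-1), v1=(+1,-1,-1), v2=(-1,+1,-1), ..., v7=(+1,+1,+1)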
class ImageSource(PointSource):
    image = WRDictEnum(image.openImages)
    points_per_pixel = Float(0.1)
    #name = Str('Density Image')
    #foo = Enum([1,2,3,4])

    helpInfo = {
        'points_per_pixel': '''
Select the average number of points (dye molecules or docking sites) per pixel in image
regions where the density values are 1. The number is a floating point fraction, e.g.
0.1, and shouldn't exceed 1. It is used for Monte-Carlo rejection of positions, and
larger values (>~0.2) will result in images which have visible pixel-grid structure
because the Monte-Carlo sampling is no longer a good approximation to random sampling
over the grid. If this is a problem for your application / you can't get high enough
density without a high acceptance fraction, use an up-sampled source image with a
smaller pixel size.
''',
        'image': '''
Select an image from the list of open images. Note that you need to open or generate
the source image you want to use so that this list is not empty. The image will be
normalised for the purpose of the simulation, with its maximum set to 1. It describes
the density of markers in the simulated sample, where values of 1 have a density of
markers as given by the `points per pixel` parameter, i.e. in the Monte-Carlo sampling
the acceptance probability = image * points_per_pixel. Smaller density values therefore
give rise to proportionally fewer markers per pixel.
''',
    }

    def helpStr(self, name):
        def cleanupHelpStr(s):
            return s.strip().replace('\n', ' ').replace('\r', '')

        return cleanupHelpStr(self.helpInfo[name])

    def default_traits_view(self):
        from traitsui.api import View, Item

        traits_view = View(Item('points_per_pixel',
                                help=self.helpStr('points_per_pixel'),
                                tooltip='mean number of marker points per pixel'),
                           Item('image',
                                help=self.helpStr('image'),
                                tooltip='select the marker density image from the list of open images'),
                           buttons=['OK', 'Help'])

        return traits_view

    def getPoints(self):
        from PYME.simulation import locify

        try:
            im = image.openImages[self.image]
        except KeyError:
            # no image of that name:
            # If uncaught this will pop up in the error dialog from 'Computation in progress', so shouldn't need
            # an explicit dialog / explicit handling. TODO - do we need an error subclass - e.g. UserError or
            # ParameterError - which the error dialog treats differently to more generic errors so as to make it
            # clear that it's something the user has done wrong rather than a bug???
            raise UserError('No open image found with name: "%s", please set "image" property of ImageSource '
                            'to a valid image name\nThis must be an image which is already open.\n\n' % self.image)

        d = im.data[:, :, 0, 0].astype('f')

        # normalise the image
        d = d / d.max()

        return locify.locify(d, pixelSize=im.pixelSize, pointsPerPixel=self.points_per_pixel)

    def get_bounds(self):
        return image.openImages[self.image].imgBounds

    def refresh_choices(self):
        ed = self.trait('image').editor

        if ed:
            try:
                ed._values_changed()
            except TypeError:
                # TODO - why can _values_changed be None??
                # is there a better way to handle/avoid this?
                pass

    def genMetaData(self, mdh):
        mdh['GeneratedPoints.Source.Type'] = 'Image'
        mdh['GeneratedPoints.Source.PointsPerPixel'] = self.points_per_pixel
        mdh['GeneratedPoints.Source.Image'] = self.image
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).
        """
        return self._pipeline.get_layer_data(self.dsname)
        #return self.datasource

    @property
    def _ds_class(self):
        # from PYME.experimental import triangle_mesh
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]

        if not self.datasource is None:
            dks = ['constant', ]
            if hasattr(self.datasource, 'keys'):
                dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Pulls vertices/normals from a triangular-mesh datasource and passes them on via update_data.

        Parameters
        ----------
        ds : PYME.experimental.triangular_mesh.TriangularMesh object

        Returns
        -------
        None
        """
        #t = ds.vertices[ds.faces]
        #n = ds.vertex_normals[ds.faces]

        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        #elif self.vertexColour == 'vertex_index':
        #    c = np.arange(0, len(x))
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        # TODO: This temporarily sets all triangles to a single colour. User should be able to select the colour.
        if c is None:
            c = np.ones(self._vertices.shape[0]) * 255  # vector of pink

        if clim is not None and c is not None and cmap is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            # cs = None
            if not self._vertices is None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ],
                           visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, using one of 3 engines (indicated above)
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='points', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        #logger.debug('Setting dsname and method')
        self.dsname = dsname
        self.method = method

        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        #logger.debug('Setting layer method to %s' % self.method)
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        try:
            cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        #self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        x, y = ds[self.x_key], ds[self.y_key]

        if not self.z_key is None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha,
                             xn=xn, yn=yn, zn=zn)
        else:
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha)

    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0,
                    xn=None, yn=None, zn=None):
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # note: the original condition tested 'clim is not None' twice; the third
        # operand was presumably meant to be the colour map
        if clim is not None and colors is not None and cmap is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        else:
            # cs = None
            if not vertices is None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None
            color_map = None
            color_limit = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self, vertices=None, normals=None, colors=None, color_map=None, color_limit=None, alpha=None):
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour',
                          visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                                 show_label=False), ],
                           visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group(Item('cmap', label='LUT'),
                           Item('alpha', visible_when="method in ['pointsprites', 'transparent_points']",
                                editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                           Item('point_size', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)))
                     ])
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        self._do = display_opts  # a dh5view display_options instance - if provided, this over-rides the clim and cmap properties
        self._im_key = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        #self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (not self._pipeline is None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            # fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]
        #return self.datasource

    @property
    def _ds_class(self):
        # from PYME.experimental import triangle_mesh
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    # def _update(self, *args, **kwargs):
    #     cdata = self._get_cdata()
    #     self.clim = [float(cdata.min()), float(cdata.max())]
    #     self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        except AttributeError:
            self._datasource_choices = [k for k, v in self._pipeline.items() if isinstance(v, self._ds_class)]

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        if do is None:
            if not (self._do is None):
                do = self._do
            else:
                return

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Parameters
        ----------
        ds : PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        #if self._do is not None:
        #    # Let display options (if provided) over-ride our settings (TODO - is this the right way to do this?)
        #    o = self._do.Offs[self.channel]
        #    g = self._do.Gains[self.channel]
        #    clim = [o, o + 1.0/g]
        #    cmap = self._do.cmaps[self.channel]
        #    visible = self._do.show[self.channel]
        #else:
        clim = self.clim
        cmap = getattr(cm, self.cmap)

        alpha = float(self.alpha)

        c0, c1 = clim

        im_key = (self.dsname, self.slice, self.channel)

        if not self._im_key == im_key:
            self._im_key = im_key
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')  # - c0)/(c1-c0)

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])

            self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     #Item('method'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
class EnsembleFitProfiles(ModuleBase):
    inputName = Input('line_profiles')

    fit_type = CStr(list(profile_fitters.ensemble_fitters.keys())[0])
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        inp = namespace[self.inputName]

        # generate LineProfileHandler from tables
        handler = LineProfileHandler()
        handler._load_profiles_from_list(inp)

        fit_class = profile_fitters.ensemble_fitters[self.fit_type]
        self.fitter = fit_class(handler)

        if self.hold_ensemble_parameter_constant:
            self.fitter.fit_profiles(self.ensemble_parameter_guess)
        else:
            self.fitter.ensemble_fit(self.ensemble_parameter_guess)

        res = tabular.RecArraySource(self.fitter.results)

        # propagate metadata, if present
        res.mdh = MetaDataHandler.NestedClassMDHandler(getattr(inp, 'mdh', None))
        res.mdh['EnsembleFitProfiles.FitType'] = self.fit_type
        res.mdh['EnsembleFitProfiles.EnsembleParameterGuess'] = self.ensemble_parameter_guess
        res.mdh['EnsembleFitProfiles.HoldEnsembleParamConstant'] = self.hold_ensemble_parameter_constant

        namespace[self.outputName] = res

    @property
    def _fitter_choices(self):
        return list(profile_fitters.ensemble_fitters.keys())

    @property
    def default_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])

    @property
    def pipeline_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'))

    @property
    def dsview_view(self):
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    buttons=['OK'])
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.

    1. Applies a median filter for denoising.
    2. Replaces out of range values with defined values.
    3. Applies a 2D Tukey filter to dampen potential edge artifacts.

    Inputs
    ------
    input_name : ImageStack

    Outputs
    -------
    output_name : ImageStack

    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``).
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels below the lower threshold are replaced by this value.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype

        # Somewhat arbitrary way to decide on chunk size
        chunk_size = 100000000 / ims.data.shape[0] / ims.data.shape[1] / dtype.itemsize
        chunk_size = max(1, chunk_size)
        chunk_size = int(chunk_size)

        tukey_mask_x = signal.tukey(ims.data.shape[0], self.tukey_size)
        tukey_mask_y = signal.tukey(ims.data.shape[1], self.tukey_size)
        self._tukey_mask_2d = np.multiply(*np.meshgrid(tukey_mask_x, tukey_mask_y, indexing='ij'))[:, :, None]

        # note: np.long is removed in newer numpy; plain int works for shape arrays
        if self.cache_clip == "":
            raw_data = np.empty(tuple(np.asarray(ims.data.shape[:3], dtype=int)), dtype=dtype)
        else:
            raw_data = np.memmap(self.cache_clip, dtype=dtype, mode='w+',
                                 shape=tuple(np.asarray(ims.data.shape[:3], dtype=int)))

        progress = 0.2 * ims.data.shape[2]
        for f in np.arange(0, ims.data.shape[2], chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(ims.data[:, :, f:f + chunk_size])

            if (f + chunk_size >= progress):
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed clipping {} of {} total images.".format(
                    time.time() - self._start_time,
                    min(f + chunk_size, ims.data.shape[2]),
                    ims.data.shape[2]))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
        Performs the actual filtering here.
        """
        if self.median_filter_size > 0:
            data = ndimage.median_filter(data, self.median_filter_size, mode='nearest')
        data[data >= self.threshold_upper] = self.clip_to_upper
        data[data <= self.threshold_lower] = self.clip_to_lower
        data -= self.clip_to_lower
        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d
        return data

    def completeMetadata(self, im):
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
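# ---------------------------------------------------------------------------
# Editor's sketch: the 2D edge-damping mask above is the outer product of two
# 1D Tukey windows. Standalone version (newer scipy exposes the window under
# scipy.signal.windows):
import numpy as np
from scipy.signal import windows

shape = (64, 48)
tukey_size = 0.25  # shape parameter: 0 -> rectangular, 1 -> Hann

wx = windows.tukey(shape[0], tukey_size)
wy = windows.tukey(shape[1], tukey_size)

mask_2d = wx[:, None] * wy[None, :]  # same result as the meshgrid/multiply above
print(mask_2d.shape, mask_2d.max(), mask_2d[0, 0])  # edges taper to 0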
class RGBImageUpload(ImageUpload):
    """
    Create an RGB (png) image and upload it to an OMERO server, optionally
    attaching localization files.

    Parameters
    ----------
    input_image : str
        name of image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabular
        inputs will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine name of image on OMERO server
    omero_dataset : str
        name of OMERO dataset to add the image to. If the dataset does not
        already exist it will be created.
    omero_project : str
        name of OMERO project to link the dataset to. If the project does not
        already exist it will be created.
    zoom : float
        how large to zoom the image
    scaling : str
        how to scale the intensity - one of 'min-max' or 'percentile'
    scaling_factor : float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta, and yellow rather than RGB. True, by default.

    Notes
    -----
    OMERO server address and user login information must be stored in the user
    PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]
    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        if (self.colorblind_friendly and (image.data.shape[3] != 1)):
            im = image_to_cmy(image, zoom=self.zoom, scaling=self.scaling,
                              scaling_factor=self.scaling_factor)
        else:
            im = image_to_rgb(image, zoom=self.zoom, scaling=self.scaling,
                              scaling_factor=self.scaling_factor)

        rgb = Image.fromarray(im, mode='RGB')
        rgb.save(path)
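# ---------------------------------------------------------------------------
# Editor's note: based on the keys listed in the RGBImageUpload docstring, the
# pyme-omero config file (e.g. ~/.PYME/plugins/config/pyme-omero) would look
# something like the following. All values are placeholders; 4064 is OMERO's
# usual default port.
#
#   user: my_omero_username
#   password: my_omero_password
#   address: omero.example.org
#   port: 4064   # optional
# ---------------------------------------------------------------------------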
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.

    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Runtime will vary hugely depending on size of dataset and settings.

    ``cache_fft`` is necessary for large datasets.

    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.* Dataset to correct.

    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results.
    output_drift_plot : Plot
        *Deprecated.* Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.* Drift-corrected dataset.

    Parameters
    ----------
    step : Int
        Setting for image construction. Step size between images.
    window : Int
        Setting for image construction. Number of frames used per image. Should be
        equal to or larger than the step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter
        (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Float
        Size of correlation window. Frames are only compared if within this frame
        range. N/A for DCC.
    multiprocessing : Float
        Enables multiprocessing.
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """

    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation

    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')

    def calc_corr_drift_from_locs(self, x, y, z, t):

        # bin edges for histogram
        bx = np.arange(x.min(), x.max() + self.binsize + 1, self.binsize)
        by = np.arange(y.min(), y.max() + self.binsize + 1, self.binsize)
        bz = np.arange(z.min(), z.max() + self.binsize + 1, self.binsize)

        # pad bin length to odd number so image size is even
        if bx.shape[0] % 2 == 0:
            bx = np.concatenate([bx, [bx[-1] + bx[1] - bx[0]]])
        if by.shape[0] % 2 == 0:
            by = np.concatenate([by, [by[-1] + by[1] - by[0]]])
        if bz.shape[0] > 2 and bz.shape[0] % 2 == 0:
            bz = np.concatenate([bz, [bz[-1] + bz[1] - bz[0]]])
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Oops. Image not correctly padded to even size."

        # start time of all windows, allow partial window near end of pipeline
        time_values = np.arange(t.min(), t.max() + 1, self.step)
        # 2d array, start and end time of windows
        time_values = np.stack([time_values, np.clip(time_values + self.window, None, t.max())], axis=1)

        n_steps = time_values.shape[0]

        # centre time of each window, for returning; the last window may have different spacing
        time_values_mid = time_values.mean(axis=1)

        if (np.any(np.diff(t) < 0)):  # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]

        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1] - 1, side='right')

        # print('time indexes')
        # print(time_values)
        # print(time_values_mid)
        # print(time_indexes)

        # Fourier transformed (and binned) set of images to correlate against one another.
        # Crude way of swapping the longest axis to the last for optimizing rfft performance.
        # Code changed for this is limited to this method.
        xyz = np.asarray([x, y, z])
        bxyz = np.asarray([bx, by, bz])
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = bxyz[dims_order]
        dims_length = dims_length[dims_order]

        # use memmap for caching if ft_cache is defined
        # (np.complex is removed in newer numpy; the builtin complex is equivalent here)
        if self.cache_fft == "":
            ft_images = np.zeros((n_steps, dims_length[0] - 1, dims_length[1] - 1,
                                  (dims_length[2] - 1) // 2 + 1,), dtype=complex)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=complex, mode='w+',
                                  shape=(n_steps, dims_length[0] - 1, dims_length[1] - 1,
                                         (dims_length[2] - 1) // 2 + 1,))

        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))
        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))

        # fill ft_images
        # if multiprocessing, caching is optional;
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:, slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size)
                    for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                if self.cache_fft == "":
                    ft_images[j] = res

                if ((i + 1) % (n_steps // 5) == 0):
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(
                        time.time() - self._start_time, i + 1, n_steps))
        else:
            # For each window we wish to correlate...
            for i, ti in enumerate(time_indexes):
                # ... we generate an image and store the ft of the image
                t_slice = slice(*ti)
                ft_images[i] = calc_fft_from_locs(xyz[:, t_slice].T, bxyz, filter_size=self.tukey_size)

                if ((i + 1) % (n_steps // 5) == 0):
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(
                        time.time() - self._start_time, i + 1, n_steps))

        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))

        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)

        # clean up ft_images, potentially a really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images

        # print(shifts)
        # print(coefs)

        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        # from PYME.util import mProfile

        # self._start_time = time.time()
        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")

        if self.multiprocessing:
            proccess_count = np.clip(multiprocessing.cpu_count() - 1, 1, None)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=proccess_count)})

        locs = namespace[self.input_for_correction]

        # mProfile.profileOn(['localisations.py', 'processing.py'])
        drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'],
                                                   locs['z'] * (0 if self.flatten_z else 1),
                                                   locs['t'])
        t_shift, shifts = self.rcc(self.shift_max, *drift_res)

        # mProfile.profileOff()
        # mProfile.report()

        if self.multiprocessing:
            self._pool.close()
            self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        if 'dx' in out.keys():
            # getting around an oddity with mappingFilter:
            # addColumn adds a new column but also keeps the old column;
            # __getitem__ returns the new column, but mappings use the old column.
            # Wrap with another level of mappingFilter so the new column becomes the 'old column'.
            out.addColumn('dx', dx)
            out.addColumn('dy', dy)
            out.addColumn('dz', dz)
            out = tabular.mappingFilter(out)
            # out.mdh = namespace[self.input_localizations].mdh
            out.setMapping('x', 'x + dx')
            out.setMapping('y', 'y + dy')
            out.setMapping('z', 'z + dz')
        else:
            out.addColumn('dx', dx)
            out.addColumn('dy', dy)
            out.addColumn('dz', dz)
            out.setMapping('x', 'x + dx')
            out.setMapping('y', 'y + dy')
            out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))

        namespace[self.output_cross_cor] = self._cc_image
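# ---------------------------------------------------------------------------
# Editor's sketch: after solving, the module integrates window-to-window
# shifts with np.cumsum and resamples them onto every event time with a cubic
# spline, as in _execute above. The same final step on toy numbers:
import numpy as np
from scipy import interpolate

t_shift = np.array([0., 2500., 5000., 7500.])  # window centre times
shifts = np.array([[0.0, 0.0, 0.0],            # per-window x/y/z shifts (nm)
                   [1.2, -0.4, 0.1],
                   [0.9, -0.2, 0.0],
                   [1.1, -0.5, 0.2]])

drift = np.cumsum(shifts, 0)  # drift relative to the first window

t_out = np.arange(0, 7501, 500.)  # localization event times
dx = interpolate.CubicSpline(t_shift, drift[:, 0])(t_out)
print(dx[:5])  # per-event x drift used in the 'x + dx' mapping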
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using the redundant cross-correlation algorithm
    from Wang et al., Optics Express 2014, 22(13) (Bo Huang's RCC algorithm).

    Base class for other RCC recipes.
    Can take a cached fft input (as a filename, not an 'input').
    Only outputs drift, as a tuple of time points and drift amounts.
    Currently not registered by itself since it is not very useful on its own.
    """

    cache_fft = File("rcc_cache.bin")
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    method = Enum(['RCC', 'MCC', 'DCC'])
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file is not blank, filled with an imagestack of the cross correlations
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2

        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                             (ft_images.shape[3] - 1) * 2]

            # flatten the shortest dimension to reduce the cross-correlations to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]), key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float is a deprecated alias for the builtin float
            cc_file_args = (self.debug_cor_file, float, tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0], dtype=cc_file_args[1], mode="w+",
                                shape=cc_file_args[2])
            # list() so cc_args can be indexed below (zip is a lazy iterator in Python 3)
            cc_args = list(zip(range(shifts.shape[0]), (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift, None, cc_args[counter])

                    if (counter + 1) % max(coefs_size // 5, 1) == 0:
                        print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                            time.time() - self._start_time, counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                       autocor_shift_cache,
                       len(ft_1_cache) * ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                       cc_args)

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, :] = res

                if (i + 1) % max(coefs_size // 5, 1) == 0:
                    print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                        time.time() - self._start_time, i + 1, coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move the time axis to the end for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert np.all(np.any(coefs, axis=1)), "Coefficient matrix filled less than expected."

        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".format(shifts.shape[0] - len(mask)))

        coefs = coefs[mask, :]
        shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (np.linalg.matrix_rank(coefs) == n_steps - 1), \
            "Something went wrong with the coefficient matrix. Not full rank."

        return shifts, coefs  # shifts.shape[0] is n_steps - 1

    def rcc(self, shift_max, t_shift, shifts, coefs):
        """
        Takes cross-correlation results and solves for the per-step drift.
        (Should probably be renamed.)
        """

        print("{:.2f} s. About to start solving shifts array.".format(time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() - self._start_time))

        if self.method == "RCC":
            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows, descending from the largest residual to the smallest,
            # but only while the matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for i, index in enumerate(residuals_arg):
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0

                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print("Could not remove all residuals over the shift_max threshold.")
                    break

            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".format(
                time.time() - self._start_time))

        # pad with 0 drift for the first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # Derived versions of RCC need to override this method.
        # '_execute' of this RCC base class is not thoroughly tested, as its use is
        # probably quite limited.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = np.clip(multiprocessing.cpu_count() - 1, 1, None)
            self._pool = multiprocessing.Pool(processes=process_count)

        drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
        t_shift, shifts = self.rcc(self.shift_max, *drift_res)

        if self.multiprocessing:
            self._pool.close()
            self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
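
# A minimal worked example (illustrative only, not called anywhere) of the
# linear system rcc() solves: the measured shift between windows i and j is
# the sum of the per-step drifts i..j-1, so each measurement contributes a row
# of ones in columns i:j of the coefficient matrix. Numbers are made up.
def _example_rcc_solve():
    import numpy as np

    n_steps = 4
    # true per-step drift we want to recover, shape (n_steps - 1, 3)
    true_drift = np.array([[1.0, 0.0, 0.0],
                           [0.5, 0.2, 0.0],
                           [0.2, 0.1, 0.0]])

    # all pairs (i, j) with i < j -- the fully redundant (corr_window <= 0) case
    pairs = [(i, j) for i in range(n_steps - 1) for j in range(i + 1, n_steps)]
    coefs = np.zeros((len(pairs), n_steps - 1))
    shifts = np.zeros((len(pairs), 3))
    for row, (i, j) in enumerate(pairs):
        coefs[row, i:j] = 1                   # same fill pattern as above
        shifts[row] = true_drift[i:j].sum(0)  # noiseless 'measured' shift

    # least-squares solution via the pseudo-inverse, as in rcc()
    drifts = np.matmul(np.linalg.pinv(coefs), shifts)
    assert np.allclose(drifts, true_drift)
    return drifts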
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles,
    and then uses the rendering engines from PYME.LMVis.layers.triangle_mesh.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    #vertexColour = CStr('', desc='Name of variable used to colour our points')
    #cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    #clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    #alpha = Float(1.0, desc='Face transparency')
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')
    #method = Enum(*ENGINES.keys(), desc='Method used to display faces')

    def __init__(self, pipeline, method='wireframe', dsname='', **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, **kwargs)

        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Takes an octree, subdivides its faces into triangles, and feeds the
        triangle points/normals to update_data.

        Parameters
        ----------
        ds :
            Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        # only keep nodes whose parent node contains at least min_points points
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)

            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            nodes = nodes[node_density >= self.density]

            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2 ** nodes['depth']) ** 3)
        else:
            # Dynamic depth rendering: render the terminating nodes, i.e. those with
            # no children. (This is equivalent to walking the octree from the root and
            # collecting leaves; the sum over the 8 child indices is zero only for a
            # leaf, hence the axis=1 reduction per node.)
            nodes = nodes[np.sum(nodes['children'], axis=1) == 0]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T
            alpha = nodes['nPoints'] * ((2.0 ** nodes['depth'])) ** 3

        # First we need the vertices of the cube. We find them from the centre c
        # provided and the box size (lx, ly, lz) provided by the octree:
        if len(nodes) > 0:
            c = nodes['centre']  # centre
            v0 = c + box_sizes * -1 / 2            # v0 = c - lx/2 - ly/2 - lz/2
            v1 = c + box_sizes * [-1, -1, 1] / 2   # v1 = c - lx/2 - ly/2 + lz/2
            v2 = c + box_sizes * [-1, 1, -1] / 2   # v2 = c - lx/2 + ly/2 - lz/2
            v3 = c + box_sizes * [-1, 1, 1] / 2    # v3 = c - lx/2 + ly/2 + lz/2
            v4 = c + box_sizes * [1, -1, -1] / 2   # v4 = c + lx/2 - ly/2 - lz/2
            v5 = c + box_sizes * [1, -1, 1] / 2    # v5 = c + lx/2 - ly/2 + lz/2
            v6 = c + box_sizes * [1, 1, -1] / 2    # v6 = c + lx/2 + ly/2 - lz/2
            v7 = c + box_sizes / 2                 # v7 = c + lx/2 + ly/2 + lz/2

            #
            #               z
            #               ^
            #               |
            #  v1 ----------v3
            #  /|           /|
            # / |          / |
            # v5----------v7 |
            # |  |    c   |  |
            # |  v0-------|--v2
            # | /         | /
            # v4----------v6---> y
            # /
            # x
            #
            # Now note that the counterclockwise triangles (when viewed straight-on)
            # formed along the faces of the cube are:
            #
            # v0 v2 v6
            # v0 v4 v5
            # v1 v3 v2
            # v2 v0 v1
            # v3 v1 v5
            # v3 v7 v6
            # v4 v6 v7
            # v5 v1 v0
            # v5 v7 v3
            # v6 v2 v3
            # v6 v4 v0
            # v7 v5 v4

            # Concatenate vertices, interleave, restore to 3x(3N) points (3xN triangles),
            # and assign the points to x, y, z vectors
            triangle_v0 = np.vstack((v0, v0, v1, v2, v3, v3, v4, v5, v5, v6, v6, v7))
            triangle_v1 = np.vstack((v2, v4, v3, v0, v1, v7, v6, v1, v7, v2, v4, v5))
            triangle_v2 = np.vstack((v6, v5, v2, v1, v5, v6, v7, v0, v3, v3, v0, v4))

            x, y, z = np.hstack((triangle_v0, triangle_v1, triangle_v2)).reshape(-1, 3).T

            # Now we create the normal as the cross product (triangle_v2 - triangle_v1) x (triangle_v0 - triangle_v1)
            triangle_normals = np.cross((triangle_v2 - triangle_v1), (triangle_v0 - triangle_v1))
            # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
            xn, yn, zn = np.repeat(triangle_normals.T, 3, axis=1)

            # scale alpha by the layer's alpha, then expand to one value per vertex
            # (12 triangles per node, 3 vertices per triangle)
            alpha = self.alpha * alpha / alpha.max()
            alpha = (alpha[None, :] * np.ones(12)[:, None])
            alpha = np.repeat(alpha.ravel(), 3)
            print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

            # Pass the restructured data to update_data
            self.update_data(x, y, z, cmap=getattr(cm, self.cmap), clim=self.clim,
                             alpha=alpha, xn=xn, yn=yn, zn=zn)
        else:
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
            Item('method'),
            Item('depth'),
            Item('min_points'),
            #Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
            #Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
            Group([Item('cmap', label='LUT'), Item('alpha')])
        ], )
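
# A small illustrative helper (not used by the layer above) showing the same
# cube-corner construction and 12-triangle tiling for a single node; 'centre'
# and 'box_size' are hypothetical stand-ins for one octree node's fields.
def _example_cube_triangles(centre, box_size):
    import numpy as np

    c = np.asarray(centre, 'f')[None, :]    # (1, 3) cube centre
    s = np.asarray(box_size, 'f')[None, :]  # (1, 3) box size (lx, ly, lz)
    # corner sign pattern, same ordering as v0..v7 above
    signs = np.array([[-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
                      [1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1]])
    v = c + s * signs / 2.0                 # (8, 3) cube corners

    # the counterclockwise face triangles listed in the comments above
    tri_idx = np.array([(0, 2, 6), (0, 4, 5), (1, 3, 2), (2, 0, 1),
                        (3, 1, 5), (3, 7, 6), (4, 6, 7), (5, 1, 0),
                        (5, 7, 3), (6, 2, 3), (6, 4, 0), (7, 5, 4)])
    return v[tri_idx]                       # (12, 3, 3): 12 triangles of 3 vertices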
class LayerWrapper(HasTraits):
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self, pipeline, method='points', ds_name='', cmap='gist_rainbow',
                 clim=[0, 1], alpha=1.0, visible=True, method_args={}):
        self._pipeline = pipeline

        self.engine = None

        self.cmap = cmap
        self.clim = clim
        self.alpha = alpha
        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        self._eng_params = dict(method_args)
        self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        return self.engine.bbox

    @property
    def colour_map(self):
        return self.engine.colour_map

    @property
    def data_source_names(self):
        names = []
        for k, v in self._namespace.items():
            names.append(k)
            if isinstance(v, tabular.ColourFilter):
                # also expose individual colour channels, using dot notation
                for c in v.getColourChans():
                    names.append('.'.join([k, c]))
        return names

    @property
    def datasource(self):
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            return self._namespace.get(ds, None).get_channel_ds(channel)
        else:
            return self._namespace.get(self.dsname, None)

    def _set_method(self):
        if self.engine:
            # carry the display parameters over to the new engine
            self._eng_params = self.engine.get('point_size', 'vertexColour')

        self.engine = ENGINES[self.method](self._context)
        self.engine.set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        # rescale the colour limits to the full range of the colour data
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        print('lw update')
        if not (self.engine is None or self.datasource is None):
            self.engine.update_from_datasource(self.datasource, getattr(cm, self.cmap),
                                               self.clim, self.alpha)
            self.on_update.send(self)

    def render(self, gl_canvas):
        if self.visible:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        try:
            cdata = self.datasource[self.engine.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname', label='Data', editor=EnumEditor(values=self.data_source_names)),
            ]),
            Item('method'),
            Group([
                Item('engine', style='custom', show_label=False,
                     editor=InstanceEditor(view=self.engine.view(self.datasource.keys()))),
            ]),
            Group([
                Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False),
            ]),
            Group([
                Item('cmap', label='LUT'),
                Item('alpha'),
                Item('visible')
            ], orientation='horizontal', layout='flow')
        ], )  # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
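
# A self-contained sketch (illustrative only) of the dot-notation convention
# used by LayerWrapper.datasource / data_source_names above: a two-part name
# 'key.channel' selects one colour channel of a ColourFilter-like datasource.
# MockColourSource is purely hypothetical.
def _example_channel_dot_notation():
    class MockColourSource(object):
        def getColourChans(self):
            return ['chan0', 'chan1']

        def get_channel_ds(self, channel):
            return 'datasource for %s' % channel

    namespace = {'filtered': MockColourSource()}

    dsname = 'filtered.chan0'
    parts = dsname.split('.')
    if len(parts) == 2:
        # same resolution logic as LayerWrapper.datasource
        ds, channel = parts
        return namespace.get(ds).get_channel_ds(channel)
    return namespace.get(dsname, None)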
class Point3DRenderLayer(VertexRenderLayer):
    point_size = Float(30.0)

    def __init__(self, x=None, y=None, z=None, colors=None, color_map=None,
                 color_limit=None, alpha=1.0, point_size=30):
        VertexRenderLayer.__init__(self, x, y, z, colors, color_map, color_limit, alpha)
        self.point_size = point_size
        self.set_shader_program(DefaultShaderProgram)

    def render(self, gl_canvas):
        self.shader_program.xmin, self.shader_program.xmax = gl_canvas.bounds['x']
        self.shader_program.ymin, self.shader_program.ymax = gl_canvas.bounds['y']
        self.shader_program.zmin, self.shader_program.zmax = gl_canvas.bounds['z']
        self.shader_program.vmin, self.shader_program.vmax = gl_canvas.bounds['v']

        with self.shader_program:
            n_vertices = self.get_vertices().shape[0]

            glVertexPointerf(self.get_vertices())
            glNormalPointerf(self.get_normals())
            glColorPointerf(self.get_colors())

            if gl_canvas:
                # point_size is in nm; dividing by the canvas pixel size converts
                # it to screen pixels. A point_size of 0 renders at one data pixel.
                if self.point_size == 0:
                    glPointSize(1 / gl_canvas.pixelsize)
                else:
                    glPointSize(self.point_size / gl_canvas.pixelsize)
            else:
                glPointSize(self.point_size)

            glDrawArrays(GL_POINTS, 0, n_vertices)

    def get_point_size(self):
        warnings.warn("use the point_size property instead", DeprecationWarning)
        return self.point_size

    def set_point_size(self, point_size):
        warnings.warn("use the point_size property instead", DeprecationWarning)
        self.point_size = point_size

    def view(self, ds_keys):
        from traitsui.api import View, Item, Group
        from PYME.ui.custom_traits_editors import CBEditor

        return View([
            Item('vertexColour', editor=CBEditor(choices=ds_keys), label='Colour'),
            Item('point_size', label='Size [nm]')
        ])
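
# A tiny illustrative helper (not used above) making the nm -> screen-pixel
# conversion of Point3DRenderLayer.render explicit; 'pixelsize_nm' is a
# hypothetical stand-in for gl_canvas.pixelsize (nm per screen pixel).
def _example_point_size_px(point_size_nm, pixelsize_nm):
    if point_size_nm == 0:
        # a point_size of 0 means 'render at exactly one data pixel'
        return 1.0 / pixelsize_nm
    return point_size_nm / pixelsize_nm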