class RuleChain(HasTraits):
    """
    Holds a list of rule factories alongside the ``post_on`` and ``protocol`` traits.

    Parameters
    ----------
    rule_factories : list, optional
        The rule factories making up the chain. A new, empty list is created
        for each instance when this is omitted.
    *args, **kwargs
        Passed through unchanged to ``HasTraits.__init__``.
    """
    post_on = Enum(POST_CHOICES)
    protocol = CStr('')

    def __init__(self, rule_factories=None, *args, **kwargs):
        # never share a default list between instances
        self.rule_factories = [] if rule_factories is None else rule_factories
        HasTraits.__init__(self, *args, **kwargs)
class ClusterStats(ModuleBase):
    """
    Compute a per-cluster statistic of one column and attach it to every row.

    Rows of the input are grouped by the integer ID found in ``IDkey`` (one
    bin per ID from 0 to the largest ID present). ``StatMethod`` is evaluated
    over ``StatKey`` within each group and the group's value is written back
    to each member row as a new column named ``<StatKey>_<StatMethod>``.

    Inputs
    ------
    inputName : tabular data containing both ``IDkey`` and ``StatKey``

    Outputs
    -------
    outputName : mapping filter around the input with the extra column added
    """
    inputName = Input('with_clumps')
    IDkey = CStr('clumpIndex')
    StatMethod = Enum(['std', 'min', 'max', 'mean', 'median', 'count', 'sum'])
    StatKey = CStr('x')
    outputName = Output('withClumpStats')

    def execute(self, namespace):
        """Add the per-cluster statistic column and store the result in the namespace."""
        from scipy.stats import binned_statistic

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # The IDs are used below to index the per-bin results, which requires an
        # integer array. Columns are frequently stored as floats, so cast once here.
        ids = np.asarray(inp[self.IDkey]).astype(int)
        prop = inp[self.StatKey]

        # one unit-width bin centred on each integer ID in 0..maxid
        maxid = int(ids.max())
        edges = -0.5 + np.arange(maxid + 2)
        resstat = binned_statistic(ids, prop, statistic=self.StatMethod, bins=edges)

        # look up each row's cluster value
        mapped.addColumn(self.StatKey + "_" + self.StatMethod, resstat[0][ids])

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def _key_choices(self):
        """Sorted column names of the current input, or an empty list if unavailable."""
        # best effort - the parent recipe / input may not exist yet when the GUI asks
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        """traitsui view offering combo-box choices for the input and column names."""
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('StatKey', editor=CBEditor(choices=self._key_choices)),
                    Item('StatMethod'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
class OutputModule(ModuleBase):
    """
    Base class for recipe modules which save or display recipe endpoints.

    Output modules are the one exception to the functional (side-effect free)
    recipe-module paradigm: they perform IO. They should therefore behave
    purely as a sink - no processing, and nothing written back to the
    namespace.
    """
    filePattern = CStr('{output_dir}/{file_stub}.csv')
    scheme = Enum('File', 'pyme-cluster://', 'pyme-cluster:// - aggregate')

    def _schemafy_filename(self, out_filename):
        """
        Translate ``out_filename`` according to the selected ``scheme``.

        'File' returns the name untouched, 'pyme-cluster://' re-roots it under
        the local cluster data root, and the aggregate scheme raises
        ``RuntimeError``. Any other scheme value yields ``None``.
        """
        scheme = self.scheme

        if scheme == 'pyme-cluster:// - aggregate':
            raise RuntimeError('Aggregation not suported')

        if scheme == 'pyme-cluster://':
            import os
            from PYME.IO import clusterIO

            relative_name = out_filename.lstrip('/')
            return os.path.join(clusterIO.local_dataroot, relative_name)

        if scheme == 'File':
            return out_filename

        return None

    def _check_outputs(self):
        """
        Debugging aid used when writing new recipe modules.

        Overridden for IO modules: they may have side effects, but may not
        declare classical namespace outputs. Raises ``RuntimeError`` if any
        outputs are declared.
        """
        if len(self.outputs) > 0:
            raise RuntimeError(
                'Output modules should not write anything to the namespace')

    def generate(self, namespace, recipe_context={}):
        """
        Entry point for interactive use (e.g. from within dh5view) rather than
        batch processing. This base implementation ignores its arguments.

        Parameters
        ----------
        namespace
        recipe_context

        Returns
        -------
        None
        """
        return None

    def execute(self, namespace):
        """
        Deliberately a no-op: output modules act as a sink and do their work
        in a save method rather than on execution.
        """
        pass
class ImageModuleBase(ModuleBase):
    """
    Base for image modules, carrying the ``dimensionality`` trait and a
    deprecated ``processFramesIndividually`` alias for it.
    """
    # NOTE - a derived class which only supports e.g. XY analysis should redefine
    # this trait so that it only includes the dimensions it supports
    dimensionality = Enum(
        'XY', 'XYZ', 'XYZT',
        desc='Which image dimensions should the filter be applied to?')

    @property
    def processFramesIndividually(self):
        """Deprecated: True when ``dimensionality`` is 'XY'. Warns on every access."""
        import warnings

        msg = ('Use dimensionality =="XY" instead to check which dimensions a filter should be applied to, chunking '
               'hints for computational optimisation')
        # stacklevel=2 attributes the warning to the code accessing the property
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        logger.warning(msg)
        return self.dimensionality == 'XY'

    @processFramesIndividually.setter
    def processFramesIndividually(self, value):
        """Deprecated: truthy values select 'XY', falsy values select 'XYZ'."""
        import warnings

        msg = ('Use dimensionality ="XY" instead to check which dimensions a filter should be applied to, chunking '
               'hints for computational optimisation')
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        logger.warning(msg)
        self.dimensionality = 'XY' if value else 'XYZ'
class LabelByRegionProperty(Filter):
    """
    Paint each contiguous region of the input mask with the value of one of
    its region properties (area, circularity or aspect ratio).

    When ``filterByProperty`` is set, regions whose property value lies
    outside [``propertyMin``, ``propertyMax``] are left at zero.
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def applyFilter(self, data, chanNum, frNum, im):
        """Return a float image in which every retained region carries its property value."""
        mask = data > 0.5
        labels, n_labels = ndimage.label(mask)
        regions = skimage.measure.regionprops(labels, None, cache=True)
        bounding_slices = ndimage.find_objects(labels)

        out = np.zeros_like(mask, dtype='float')

        for region in regions:
            # find_objects is indexed by label - 1
            where = bounding_slices[region.label - 1]
            region_mask = labels[where] == region.label

            if self.regionProperty == 'area':
                propValue = region.area
            elif self.regionProperty == 'aspectratio':
                propValue = region.major_axis_length / region.minor_axis_length
            elif self.regionProperty == 'circularity':
                propValue = 4 * math.pi * region.area / (region.perimeter * region.perimeter)

            # without filtering every region is kept, otherwise only those in range
            keep = (not self.filterByProperty) or (
                (propValue >= self.propertyMin) and (propValue <= self.propertyMax))
            if keep:
                out[where] += region_mask * propValue

        return out

    def completeMetadata(self, im):
        """Record which property was used for labelling in the image metadata."""
        im.mdh['Labelling.Property'] = self.regionProperty
class FlexiThreshold(Filter):
    """
    Threshold an image using one of several selectable methods.

    Available methods: simple, fractional, otsu, isodata, li, yen. All but
    'simple' and 'fractional' are looked up as ``threshold_<method>`` on
    ``skf``.
    """
    # newer skimage has minimum, mean and triangle as well
    method = Enum('simple', 'fractional', 'otsu', 'isodata', 'li', 'yen')
    parameter = Float(0.5)
    # used to be 10 - increased to a large value for newer PYME renderings
    clipAt = Float(2e6)

    def fractionalThreshold(self, data):
        """
        Find the histogram edge at which the cumulative intensity-weighted
        count is closest to ``(1 - parameter)`` of its total.
        """
        counts, edges = np.histogram(data, bins=5000)

        # NOTE: the left edge of each bin is used as its representative value
        left_edges = edges[:-1]
        running_total = np.cumsum(counts * left_edges)

        target = running_total[-1] * (1 - self.parameter)
        idx = np.argmin(abs(running_total - target))
        return edges[idx]

    def applyFilter(self, data, chanNum, frNum, im):
        """Return the boolean mask ``data > threshold`` for the selected method."""
        if self.method == 'simple':
            # the parameter is the threshold; no clipping involved
            threshold = self.parameter
        elif self.method == 'fractional':
            clipped = np.clip(data, None, self.clipAt)
            threshold = self.fractionalThreshold(clipped)
        else:
            finder = getattr(skf, 'threshold_%s' % self.method)
            clipped = np.clip(data, None, self.clipAt)
            threshold = finder(clipped)

        return data > threshold

    def completeMetadata(self, im):
        """Store the threshold parameter and method in the image metadata."""
        im.mdh['Processing.ThresholdParameter'] = self.parameter
        im.mdh['Processing.ThresholdMethod'] = self.method
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, rendered with one of the engines
    registered in ``ENGINES``.

    The layer pulls coordinates (and optionally normals and a colour variable)
    from a named pipeline datasource, converts them to vertex / normal /
    colour arrays and fires ``on_update`` whenever they change.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='points', dsname='', context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline
            Object providing ``get_layer_data``, ``layer_data_source_names`` and
            an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` selecting the render engine.
        dsname : str
            Name of the datasource to display.
        context
            Passed to ``EngineLayer.__init__``.
        **kwargs
            Passed to ``EngineLayer.__init__`` and then applied as trait values.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw
        # when parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        self.method = method

        # called explicitly so the engine is created even if `method` did not change
        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically
        # update ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Instantiate the engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the data for the current colour variable, or ``[0, 1]`` as a
        placeholder when it cannot be looked up.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            # KeyError: no such column. TypeError: no datasource (None is not
            # subscriptable) - same fallback as TriangleRenderLayer.
            cdata = np.array([0, 1])
        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to span the new colour variable (fires ``update`` via the clim trait)."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Refresh the GUI choice lists and, when possible, the render data."""
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[xmin, ymin, zmin, xmax, ymax, zmax]`` of the points, or None."""
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Extract coordinates, colour data and (if present) normals from ``ds``
        and pass them to ``update_data``.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        # fall back to a flat z when there is no z key / column
        if not self.z_key is None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha,
                             xn=xn, yn=yn, zn=zn)
        else:
            self.update_data(x, y, z, c, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=self.alpha)

    def update_data(self, x=None, y=None, z=None, colors=None, cmap=None, clim=None, alpha=1.0, xn=None, yn=None,
                    zn=None):
        """
        Build the (N, 3) vertex and normal arrays and the (N, 4) colour array.

        Colours are produced by mapping ``colors`` through ``cmap`` after
        scaling with ``clim``. If any of ``colors``, ``cmap`` or ``clim`` is
        missing, all points are given the colour (1, 1, 1, 1).
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # a colour map is required to map colours - previously `clim` was tested twice and `cmap` not at all,
        # so a missing cmap raised a TypeError on the call below
        if clim is not None and colors is not None and cmap is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha
            cs = cs.ravel().reshape(len(colors), 4)
        else:
            if not vertices is None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self, vertices=None, normals=None, colors=None, color_map=None, color_limit=None, alpha=None):
        """Store each supplied value; arguments left as None keep the current value."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        """traitsui view exposing the datasource, method, colouring and point-size settings."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour',
                 visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group([
                Item('clim', editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                     show_label=False),
            ], visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha', visible_when="method in ['pointsprites', 'transparent_points']",
                     editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                Item('point_size', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)))
        ])

    def default_traits_view(self):
        return self.default_view
class RGBImageUpload(ImageUpload):
    """
    Render an image to an RGB png and upload it to an OMERO server,
    optionally attaching localization files.

    Parameters
    ----------
    input_image : str
        name of image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabulars
        will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine name of image on OMERO server
    omero_dataset : str
        name of OMERO dataset to add the image to. If the dataset does not
        already exist it will be created.
    omero_project : str
        name of OMERO project to link the dataset to. If the project does not
        already exist it will be created.
    zoom : float
        how large to zoom the image
    scaling : str
        how to scale the intensity - one of 'min-max' or 'percentile'
    scaling_factor : float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta, and yellow rather than RGB. True, by default.

    Notes
    -----
    OMERO server address and user login information must be stored in the
    user PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        """Convert ``image`` to an RGB array and write it to ``path`` using PIL."""
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        # the CMY rendering is only used when requested and the 4th data axis is not of length 1
        use_cmy = self.colorblind_friendly and (image.data.shape[3] != 1)
        render = image_to_cmy if use_cmy else image_to_rgb

        pixels = render(image, zoom=self.zoom, scaling=self.scaling,
                        scaling_factor=self.scaling_factor)

        Image.fromarray(pixels, mode='RGB').save(path)
class ImageUpload(OutputModule):
    """
    Upload a PYME ImageStack to an OMERO server, optionally attaching
    localization files.

    Parameters
    ----------
    input_image : str
        name of image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabulars
        will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine name of image on OMERO server. 'file_stub' will
        be set automatically.
    omero_dataset : str
        name of OMERO dataset to add the image to. If the dataset does not
        already exist it will be created. Can use sample metadata entries
        using {format} syntax
    omero_project : str
        name of OMERO project to link the dataset to. If the project does not
        already exist it will be created. Can use sample metadata entries
        using {format} syntax

    Notes
    -----
    OMERO server address and user login information must be stored in the
    user PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    input_image = Input('')
    input_localization_attachments = DictStrStr()
    filePattern = '{file_stub}.tif'
    scheme = Enum(['OMERO'])
    omero_project = CStr('')
    omero_dataset = CStr('{Sample.SlideRef}')

    def _save(self, image, path):
        """Save ``image`` to ``path``, with the extension forced to '.tif'."""
        stem, _ = os.path.splitext(path)
        image.Save(stem + '.tif')

    def save(self, namespace, context={}):
        """
        Write the image (and any localization attachments) to a temporary
        directory and upload them to OMERO.

        Parameters
        ----------
        namespace : dict
            The recipe namespace
        context : dict
            Information about the source file to allow pattern substitution to
            generate the output name. At least 'file_stub' (which is the
            filename without any extension) should be resolved.
        """
        from pyme_omero import core
        from tempfile import TemporaryDirectory

        out_filename = self.filePattern.format(**context)
        im = namespace[self.input_image]

        sample_md = {}
        if hasattr(im, 'mdh'):
            # metadata keys contain periods, which str.format treats as attribute
            # access - so expose the 'Sample.*' entries as attributes of one object
            sample = Sample()
            sample_keys = [key for key in im.mdh.keys() if key.startswith('Sample.')]
            for key in sample_keys:
                setattr(sample, key.split('Sample.')[-1], im.mdh[key])
            sample_md = dict(Sample=sample)

        dataset = self.omero_dataset.format(**sample_md)
        project = self.omero_project.format(**sample_md)

        with TemporaryDirectory() as temp_dir:
            out_filename = os.path.join(temp_dir, out_filename)
            self._save(im, out_filename)

            loc_filenames = []
            for loc_key, loc_stub in self.input_localization_attachments.items():
                loc_filename = os.path.join(temp_dir, loc_stub)
                if os.path.splitext(loc_filename)[-1] == '':
                    # default to hdf unless h5r is manually specified
                    loc_filename = loc_filename + '.hdf'
                loc_filenames.append(loc_filename)

                # metadata is optional on the tabular data
                mdh = getattr(namespace[loc_key], 'mdh', None)
                namespace[loc_key].to_hdf(loc_filename, loc_key, metadata=mdh)

            image_id = core.upload_image_from_file(out_filename, dataset, project, loc_filenames)

            # if an h5r file is the principle input, upload it (best effort)
            try:
                principle = os.path.join(context['input_dir'], context['file_stub']) + '.h5r'
                with core.local_or_named_temp_filename(principle) as f:
                    core.connect_and_upload_file_annotation(image_id, f, namespace='pyme.localizations')
            except (KeyError, IOError):
                pass

    @property
    def inputs(self):
        """Set of namespace keys consumed: the attachments plus the image itself."""
        return set(self.input_localization_attachments.keys()) | {self.input_image}

    @property
    def default_view(self):
        """traitsui view including the attachment editor, or None without a wx app."""
        import wx
        if wx.GetApp() is None:
            return None

        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import DictChoiceStrEditor, CBEditor

        _, _, params = self.get_params()

        leading_items = [
            Item(name='input_image', editor=CBEditor(choices=self._namespace_keys)),
            Item(name='input_localization_attachments',
                 editor=DictChoiceStrEditor(choices=self._namespace_keys)),
            Item('_'),
        ]
        return View(leading_items + self._view_items(params), buttons=['OK', 'Cancel'])

    @property
    def pipeline_view(self):
        return self.default_view

    @property
    def no_localization_view(self):
        """As ``default_view`` but without the localization attachment editor."""
        import wx
        if wx.GetApp() is None:
            return None

        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        _, _, params = self.get_params()

        leading_items = [
            Item(name='input_image', editor=CBEditor(choices=self._namespace_keys)),
            Item('_'),
        ]
        return View(leading_items + self._view_items(params), buttons=['OK', 'Cancel'])
class CalculateFRCFromLocs(CalculateFRCBase):
    """
    Generates a pair of images from localization data and calculates the fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    inputName : TabularBase
        Localization data.

    Outputs
    -------
    outputName : TabularBase
        Localization data labeled with how it was divided (FRC_group).
    output_images : ImageStack
        Pair of 2D or 3D histogram rendered images.
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    split_method : string
        Different methods of dividing data into halves.
    pixel_size_in_nm : float
        Pixel size used for rendering the images.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    """
    # NOTE: traits/outputs documented above but not declared here (multiprocessing, output_frc_dict,
    # output_frc_raw, ...) are used via `self` and are presumably declared on CalculateFRCBase.
    inputName = Input('Localizations')
    split_method = Enum(['halves_random', 'halves_time', 'halves_100_time_chunk', 'halves_10_time_chunk',
                         'fixed_time', 'fixed_10_time_chunk'])
    pixel_size_in_nm = Float(5)
    flatten_z = Bool(True)

    outputName = Output('FRC_ds')
    output_images = Output('FRC_images')

    def execute(self, namespace):
        """
        Split the localizations, render the image pair, compute the FRC and
        store the images and results in ``namespace``.
        """
        self._namespace = namespace
        import multiprocessing

        use_pool = self.multiprocessing
        if use_pool:
            # Two images -> at most two workers, leaving one core free where possible. Always at least one
            # worker: np.clip(2, 1, cpu_count - 1) evaluated to 0 on single-core hosts, which Pool rejects.
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            pipeline = namespace[self.inputName]
            mapped_pipeline = tabular.mappingFilter(pipeline)

            # builtin float instead of the np.float alias (removed in NumPy 1.24)
            self._pixel_size_in_nm = self.pixel_size_in_nm * np.ones(3, dtype=float)

            image_pair = self.generate_image_pair(mapped_pipeline)
            image_pair = self.preprocess_images(image_pair)

            # Should use DensityMapping recipe eventually when it is ready.
            mdh = MetaDataHandler.NestedClassMDHandler()
            mdh['Rendering.Method'] = "np.histogramdd"
            if 'imageID' in pipeline.mdh.getEntryNames():
                mdh['Rendering.SourceImageID'] = pipeline.mdh['imageID']
            try:
                mdh['Rendering.SourceFilename'] = pipeline.resultsSource.h5f.filename
            except Exception:
                # best effort - not every pipeline has a file-backed results source
                pass
            mdh.Source = MetaDataHandler.NestedClassMDHandler(pipeline.mdh)

            mdh['Rendering.NEventsRendered'] = [image_pair[0].sum(), image_pair[1].sum()]
            mdh['voxelsize.units'] = 'um'
            mdh['voxelsize.x'] = self.pixel_size_in_nm * 1E-3
            mdh['voxelsize.y'] = self.pixel_size_in_nm * 1E-3

            ims = ImageStack(data=np.stack(image_pair, axis=-1), mdh=mdh)
            namespace[self.output_images] = ims

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, pipeline.mdh)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release the worker processes even if something above failed
            if use_pool:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)

    def generate_image_pair(self, mapped_pipeline):
        """
        Split the localizations into two groups according to ``split_method``
        and histogram each group into an image.

        Adds a boolean 'FRC_group' column to ``mapped_pipeline`` and stores it
        in the namespace under ``outputName``.

        Returns
        -------
        image_a, image_b : ndarray
            Histograms of the rows with FRC_group False and True respectively.
        """
        t = mapped_pipeline['t']

        # Split localizations into 2 sets (builtin bool instead of the removed np.bool alias)
        mask = np.zeros(t.shape, dtype=bool)

        if self.split_method == 'halves_time':
            sort_arg = np.argsort(t)
            mask[sort_arg[:len(sort_arg) // 2]] = 1
        elif self.split_method == 'halves_random':
            mask[:len(mask) // 2] = 1
            np.random.shuffle(mask)
        elif self.split_method == 'halves_100_time_chunk':
            # alternate 100 time-ordered chunks between the two groups
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] / 100.0
            for i in range(50):
                mask[sort_arg[int(np.round(i * 2 * chunksize)):int(np.round((i * 2 + 1) * chunksize))]] = 1
        elif self.split_method == 'halves_10_time_chunk':
            # alternate 10 time-ordered chunks between the two groups
            sort_arg = np.argsort(t)
            chunksize = mask.shape[0] * 0.1
            for i in range(5):
                mask[sort_arg[int(np.round(i * 2 * chunksize)):int(np.round((i * 2 + 1) * chunksize))]] = 1
        elif self.split_method == 'fixed_time':
            # np.ptp rather than the ndarray.ptp method, which was removed in NumPy 2.0
            time_cutoff = (np.ptp(t) + 1) // 2 + t.min()
            mask[t < time_cutoff] = 1
        elif self.split_method == 'fixed_10_time_chunk':
            time_cutoffs = np.linspace(t.min(), t.max(), 11, dtype=float)
            for i in range((time_cutoffs.shape[0] - 1) // 2):
                mask[(t > time_cutoffs[i * 2]) & (t < time_cutoffs[i * 2 + 1])] = 1

        if self.split_method.startswith('halves_'):
            assert np.abs(mask.sum() - (~mask).sum()) <= 1, "datasets uneven, {} vs {}".format(mask.sum(),
                                                                                              (~mask).sum())
        else:
            print("variable counts between images, {} vs {}".format(mask.sum(), (~mask).sum()))

        mapped_pipeline.addColumn("FRC_group", mask)
        self._namespace[self.outputName] = mapped_pipeline

        dims = ['x', 'y', 'z']
        if self.flatten_z:
            dims.remove('z')

        # Simple histogram binning with pixel_size_in_nm wide bins spanning the data
        bins = list()
        data = list()
        for dim in dims:
            bins.append(np.arange(np.floor(mapped_pipeline[dim].min()).astype(int),
                                  np.ceil(mapped_pipeline[dim].max()).astype(int) + self.pixel_size_in_nm,
                                  self.pixel_size_in_nm))
            data.append(mapped_pipeline[dim])
        data = np.stack(data, axis=1)

        if self.multiprocessing:
            # bin both groups concurrently on the pool created in execute()
            results = list()
            results.append(self._pool.apply_async(np.histogramdd, (data[mask == 0],), kwds={"bins": bins}))
            results.append(self._pool.apply_async(np.histogramdd, (data[mask == 1],), kwds={"bins": bins}))

            image_a = results[0].get()[0]
            image_b = results[1].get()[0]
        else:
            image_a = np.histogramdd(data[mask == 0], bins=bins)[0]
            image_b = np.histogramdd(data[mask == 1], bins=bins)[0]

        return image_a, image_b
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.

    Takes a mesh datasource from the pipeline, expands it into per-triangle-corner
    vertex, normal and colour arrays and fires ``on_update`` when these change.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    # choice lists backing the GUI drop-downs; refreshed in update()
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline
            Object providing ``get_layer_data``, ``dataSources`` and an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` selecting the render engine.
        dsname : str
            Name of the datasource to display.
        context
            Passed to ``EngineLayer.__init__``.
        **kwargs
            Passed to ``EngineLayer.__init__`` and then applied as trait values.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # assigning an unchanged value would not fire the trait handler, so
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to, obtained from the pipeline's
        ``get_layer_data`` using our dsname property.
        """
        return self._pipeline.get_layer_data(self.dsname)

    @property
    def _ds_class(self):
        """Class used in update() to pick out eligible datasources (imported lazily)."""
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """Instantiate the engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the data for the current colour variable, or ``[0, 1]`` as a
        placeholder when the lookup raises KeyError or TypeError.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])
        return cdata

    def _update(self, *args, **kwargs):
        """Handler for vertexColour changes: rescale ``clim`` to the new data, then refresh."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """Refresh the GUI choice lists and, when an engine and datasource exist, the render data."""
        # only offer datasources which are instances of the mesh class
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]

        if not self.datasource is None:
            # 'constant' is always available as a colouring option
            dks = ['constant',]
            if hasattr(self.datasource, 'keys'):
                dks = dks + sorted(self.datasource.keys())
            self._datasource_keys = dks

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[xmin, ymin, zmin, xmax, ymax, zmax]`` of the vertices, or None."""
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build the vertex, normal and colour arrays from a mesh datasource.

        Reads ``ds.vertices``, ``ds.faces`` and either ``ds.vertex_normals`` or
        ``ds.face_normals`` (depending on ``normal_mode``), and stores the
        results on the layer.

        Parameters
        ----------
        ds :
            Mesh object providing the attributes above and, for non-constant
            colouring, item access by column name.

        Returns
        -------
        None
        """
        # expand to one vertex per triangle corner
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            # repeat each face normal for the three corners of that face
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        # TODO: User should be able to select color.
        # NOTE(review): `c` is always assigned above, so this fallback looks unreachable - confirm
        if c is None:
            c = np.ones(self._vertices.shape[0]) * 255  # vector of pink

        if clim is not None and c is not None and cmap is not None:
            # scale by clim, then map to RGBA
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                # for these methods alpha becomes per-vertex, scaled by the normalised colour value
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if not self._vertices is None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        """traitsui view exposing the datasource, method, normal mode and colouring settings."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ],
                           visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class AlignPSF(ModuleBase):
    """
    Align PSF stacks by redundant cross correlation.

    The input image is indexed as (x, y, z, stack). Every pair of stacks is
    cross-correlated, the correlation peak is located (Gaussian fit or RBF
    interpolation), and the resulting pairwise shifts are solved in a least
    squares sense for one shift per stack. Each stack is then translated by
    its shift with a Fourier phase ramp.

    Parameters
    ----------
    normalize_z : normalise intensity per z-plane (True) or per stack (False)
    tukey : alpha of the Tukey window applied along x, y and z before
        correlating; values <= 0 disable windowing
    rcc_tolerance : scaled by ``1e3 / mdh['voxelsize.x']`` to give the residual
        threshold above which a pairwise shift is discarded
    z_crop_half_roi : half-width (in planes) of the central z range used for
        the correlation
    peak_detect : method used to locate the cross-correlation peak
    debug : if True, output the shifted *normalised and cropped* stacks rather
        than the shifted raw stacks
    """
    inputName = Input('psf_cropped')
    normalize_z = Bool(True)
    tukey = Float(0.50)
    rcc_tolerance = Float(5.0)
    z_crop_half_roi = Int(15)
    peak_detect = Enum(['Gaussian', 'RBF'])
    debug = Bool(False)
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        """Compute per-stack shifts and store the aligned stacks in `namespace`."""
        # calculate_shifts() writes its diagnostic image stacks through this handle
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C'
        psf_stack = ims.data[:, :, :, :]

        z_mid = psf_stack.shape[2] // 2
        # clamp at 0 so a stack shorter than the ROI does not produce a negative
        # (wrap-around) slice start
        z_slice = slice(max(z_mid - self.z_crop_half_roi, 0),
                        z_mid + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:, :, z_slice, :])

        if self.tukey > 0:
            # separable 3D taper, broadcast over the stack axis
            masks = [signal.tukey(dim_len, self.tukey)
                     for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:, :, :, None]

        drifts = self.calculate_shifts(
            cleaned_psf_stack,
            self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        to_shift = cleaned_psf_stack if self.debug else psf_stack
        namespace[self.output_images] = ImageStack(
            self.shift_images(to_shift, drifts), mdh=ims.mdh)

    def normalize_images(self, psf_stack):
        """
        Return a background-subtracted, intensity-normalised copy of `psf_stack`.

        Values are scaled so the maximum (per plane or per stack, depending on
        `normalize_z`) is 1.05, then 0.05 is subtracted and negatives are
        clipped, leaving data in [0, 1].
        """
        # in case it is already bg subtracted
        cleaned_psf_stack = np.clip(psf_stack, 0, None)

        # subtract bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0, 1, 2), keepdims=True)

        if self.normalize_z:
            # normalize intensity per plane
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1), keepdims=True) / 1.05
        else:
            # normalize intensity per psf stack
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1, 2), keepdims=True) / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """
        Estimate one (x, y, z) shift per stack from all pairwise correlations.

        Parameters
        ----------
        psf_stack : 4D array indexed (x, y, z, stack)
        drift_tolerance : residual norm above which a pairwise measurement is
            dropped before the second solve

        Returns
        -------
        drifts : array of shape (n_stacks, 3)

        Side effects: stores the thresholded and the fitted cross-correlation
        volumes in ``self._namespace`` and attempts to draw a scatter plot of
        the result.
        """
        n_steps = psf_stack.shape[3]
        # number of unique (i, j) pairs with i < j; floor division keeps this an
        # int so it is usable as an array dimension on Python 3
        coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        cc_shape = tuple(psf_stack.shape[:3]) + (coefs_size,)
        output_cross_corr_images = np.zeros(cc_shape, dtype=float)
        output_cross_corr_images_fitted = np.zeros(cc_shape, dtype=float)

        counter = 0
        for i in np.arange(0, n_steps - 1):
            for j in np.arange(i + 1, n_steps):
                # shift(i -> j) is the sum of the step-to-step drifts i..j-1
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))

                correlate_result = signal.correlate(psf_stack[:, :, :, i],
                                                    psf_stack[:, :, :, j],
                                                    mode="same")
                correlate_result -= correlate_result.min()
                correlate_result /= correlate_result.max()

                # keep only the upper half of the normalised correlation
                threshold = 0.50
                correlate_result[correlate_result < threshold] = np.nan

                labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
                # protects against > 1 peak in the cross correlation results
                # shouldn't happen anyway, but at least avoid fitting a single to multi-modal data
                if labeled_counts > 1:
                    max_order = np.argsort(
                        ndimage.maximum(correlate_result, labeled_image,
                                        np.arange(labeled_counts) + 1)) + 1
                    # argsort is ascending, so the label holding the highest
                    # peak is the LAST entry; keep that one and discard the rest
                    correlate_result[labeled_image != max_order[-1]] = np.nan

                output_cross_corr_images[:, :, :, counter] = np.nan_to_num(correlate_result)

                # zero-centred coordinate axis for every dimension
                dims = list()
                for dim in correlate_result.shape:
                    axis_coords = np.arange(dim)
                    dims.append(axis_coords - axis_coords.mean())

                if self.peak_detect == "Gaussian":
                    res = optimize.least_squares(guassian_nd_error,
                                                 [1, 0, 0, 5., 0, 5., 0, 30.],
                                                 args=(dims, correlate_result))
                    output_cross_corr_images_fitted[:, :, :, counter] = gaussian_nd(res.x, dims)

                    # the centre along each axis sits at parameter indices 2, 4, 6
                    shifts[counter, 0] = res.x[2]
                    shifts[counter, 1] = res.x[4]
                    shifts[counter, 2] = res.x[6]
                elif self.peak_detect == "RBF":
                    rbf_interpolator = build_rbf(dims, correlate_result)
                    res = optimize.minimize(rbf_nd_error,
                                            [correlate_result.shape[0] * 0.5,
                                             correlate_result.shape[1] * 0.5,
                                             correlate_result.shape[2] * 0.5],
                                            args=rbf_interpolator)
                    output_cross_corr_images_fitted[:, :, :, counter] = rbf_nd(rbf_interpolator, dims)
                    shifts[counter, :] = res.x
                else:
                    raise Exception("peak founding method not recognised")

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        # first pass: least squares solution using every pairwise shift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals = np.matmul(coefs, drifts) - shifts
        residuals_dist = np.linalg.norm(residuals, axis=1)

        shift_max = drift_tolerance
        # Sort and mask residual errors
        residuals_arg = np.argsort(-residuals_dist)
        residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

        # Remove coefs rows
        # Descending from largest residuals to small
        # Only if matrix remains full rank
        coefs_temp = np.empty_like(coefs)
        counter = 0
        for index in residuals_arg:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                counter += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(counter))

        # second pass with the outlying measurements zeroed out
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        # first stack is the reference (zero shift); accumulate step-to-step
        # drifts into shifts relative to it
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)
        np.cumsum(drifts, axis=0, out=drifts)

        # offset of the thresholded mean PSF's centre of mass from the array centre
        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0, 1, 2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        center_offset = ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape) * 0.5

        drifts = drifts - center_offset

        # diagnostic plot; best effort only, so any failure (e.g. no display
        # backend) is reported and otherwise ignored
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6, 3))
            axes[0].scatter(drifts[:, 0], drifts[:, 1], s=50)
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            axes[1].scatter(drifts[:, 0], drifts[:, 2], s=50)
            axes[1].set_xlabel('x')
            axes[1].set_ylabel('z')

            for ax in axes:
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')

            fig.tight_layout()
        except Exception as e:
            print(e)

        return drifts

    def shift_images(self, psf_stack, shifts):
        """
        Translate every stack in `psf_stack` by the matching row of `shifts`.

        The translation is applied as a linear phase ramp in Fourier space and
        the magnitude of the inverse transform is returned, in an array with
        the same shape and dtype as `psf_stack`.
        """
        kx = (np.fft.fftfreq(psf_stack.shape[0]))
        ky = (np.fft.fftfreq(psf_stack.shape[1]))
        kz = (np.fft.fftfreq(psf_stack.shape[2]))
        kx, ky, kz = np.meshgrid(kx, ky, kz, indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in np.arange(psf_stack.shape[3]):
            psf = psf_stack[:, :, :, i]
            ft_image = np.fft.fftn(psf)
            shift = shifts[i]
            phase_ramp = np.exp(-2j * np.pi * (kx * shift[0] + ky * shift[1] + kz * shift[2]))
            shifted_images[:, :, :, i] = np.abs(np.fft.ifftn(ft_image * phase_ramp))

        return shifted_images
class FiducialTrack(ModuleBase):
    """
    Extract average fiducial track from input pipeline

    Parameters
    ----------
        radiusMultiplier: this number is multiplied with error_x to obtain search radius for clustering
        timeWindow: the window along the time dimension used for clustering
        filterScale: the size of the filter kernel used to smooth the resulting average fiducial track
        filterMethod: enumerated choice of filter methods for smoothing operation (Gaussian, Median or Uniform kernel)
        clumpMinSize: passed to the clumping step as its `clumpMinSize` argument
        singleFiducial: if True, the tracks are not aligned to each other (see `execute`)

    Notes
    -----
    Output is a new pipeline with a `fiducial_<dim>` column for every dimension
    returned by the averaging step, plus an `isFiducial` column.
    """
    import PYMEcs.Analysis.trackFiducials as tfs

    inputName = Input('filtered')
    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    # materialise the keys: on Python 3 `dict.keys()` is a view, which Enum
    # would treat as a single value instead of a list of choices
    filterMethod = Enum(list(tfs.FILTER_FUNCS.keys()))
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        """Compute the fiducial tracks for the input and store the augmented data source."""
        import PYMEcs.Analysis.trackFiducials as tfs

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # if all data is from a single fiducial we do not need to align
        # we then avoid problems with incomplete tracks giving rise to offsets between
        # fiducial track fragments
        align = not self.singleFiducial

        t, x, y, z, isFiducial = tfs.extractTrajectoriesClump(
            inp,
            clumpRadiusVar='error_x',
            clumpRadiusMultiplier=self.radiusMultiplier,
            timeWindow=self.timeWindow,
            clumpMinSize=self.clumpMinSize,
            align=align)
        rawtracks = (t, x, y, z)
        tracks = tfs.AverageTrack(inp, rawtracks,
                                  filter=self.filterMethod,
                                  filterScale=self.filterScale,
                                  align=align)

        # add tracks for all calculated dims to output
        for dim in tracks.keys():
            mapped.addColumn('fiducial_%s' % dim, tracks[dim])
        mapped.addColumn('isFiducial', isFiducial)

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        return ['columns']
class TrackRenderLayer(EngineLayer):
    """
    A layer for viewing tracking data

    Accepts either a ClumpManager (tracks are taken from its `.all` attribute)
    or a tabular data source with a clump-index column named by `clump_key`.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour points')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    line_width = Float(1.0, desc='Track line width')
    method = Enum(*ENGINES.keys(), desc='Method used to display tracks')
    clump_key = CStr('clumpIndex', desc="Name of column containing the track identifier")
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of points')
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self, pipeline, method='tracks', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, clump_key')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        if self.method == method:
            # assignment would not fire the trait handler, so build the engine explicitly
            self._set_method()
        else:
            # the 'method' trait handler calls _set_method() for us; calling it
            # again here would construct the engine twice
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Instantiate the rendering engine for the current `method` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the values of the colour variable, or `[0, 1]` if it is not a valid key."""
        try:
            if isinstance(self.datasource, ClumpManager):
                cdata = []
                for track in self.datasource.all:
                    cdata.extend(track[self.vertexColour])
                cdata = np.array(cdata)
            else:
                # Assume tabular dataset
                cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset `clim` to the (NaN-ignoring) range of the colour variable."""
        cdata = self._get_cdata()
        self.clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]

    def update(self, *args, **kwargs):
        """Refresh the GUI choice lists and rebuild the render buffers from the datasource."""
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            if isinstance(self.datasource, ClumpManager):
                # Grab the keys from the first Track in the ClumpManager
                self._datasource_keys = sorted(self.datasource[0].keys())
            else:
                # Assume we have a tabular data source
                self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build vertex, normal and colour arrays (in clump order) from `ds`.

        Also sets `clumpSizes` / `clumpStarts`, the bounding box and the
        cached colour map, colour limits and alpha.
        """
        if isinstance(ds, ClumpManager):
            x = []
            y = []
            z = []
            c = []
            self.clumpSizes = []

            # Copy data from tracks. This is already in clump order
            # thanks to ClumpManager
            for track in ds.all:
                x.extend(track['x'])
                y.extend(track['y'])
                z.extend(track['z'])
                self.clumpSizes.append(track.nEvents)
                if not self.vertexColour == '':
                    c.extend(track[self.vertexColour])
                else:
                    c.extend([0 for i in track['x']])

            x = np.array(x)
            y = np.array(y)
            z = np.array(z)
            c = np.array(c)
        else:
            # Assume tabular data source
            x, y = ds[self.x_key], ds[self.y_key]

            if not self.z_key is None:
                try:
                    z = ds[self.z_key]
                except KeyError:
                    z = 0 * x
            else:
                z = 0 * x

            if not self.vertexColour == '':
                c = ds[self.vertexColour]
            else:
                c = 0 * x

            # Work out clump start and finish indices
            # TODO - optimize / precompute????
            ci = ds[self.clump_key]

            NClumps = int(ci.max())

            # bucket point indices by clump; bucket k holds clump index k + 1
            clist = [[] for i in range(NClumps)]
            for i, cl_i in enumerate(ci):
                clist[int(cl_i - 1)].append(i)

            # This and self.clumpStarts are class attributes for
            # compatibility with the old Track rendering layer,
            # PYME.LMVis.gl_render3D.TrackLayer
            self.clumpSizes = [len(cl_i) for cl_i in clist]

            # reorder x, y, z, c in clump order
            # (builtin int: the np.int alias no longer exists in current NumPy)
            I = np.hstack([np.array(cl) for cl in clist]).astype(int)

            x = x[I]
            y = y[I]
            z = z[I]
            c = c[I]

        self.clumpStarts = np.cumsum([0, ] + self.clumpSizes)

        # do normal vertex stuff
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

        self._vertices = vertices
        self._normals = -0.69 * np.ones(vertices.shape)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        if clim is not None:
            cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = float(self.alpha)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if not vertices is None:
                self._colors = np.ones((vertices.shape[0], 4), 'f')

        self._color_map = cmap
        self._color_limit = clim
        self._alpha = float(self.alpha)

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour',
                          visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group([Item('clim',
                                 editor=HistLimitsEditor(data=self._get_cdata, update_signal=self.on_update),
                                 show_label=False), ],
                           visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
                     Group(Item('cmap', label='LUT'),
                           Item('alpha', editor=TextEditor(auto_set=False, enter_set=True, evaluate=float)),
                           Item('line_width'))
                     ])

    def default_traits_view(self):
        return self.default_view
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """
        Cross-correlate pairs of Fourier-transformed images.

        Parameters
        ----------
        ft_images : array whose first axis indexes time steps

        Returns
        -------
        shifts : (n_pairs, 3) array of measured pairwise shifts
        coefs : (n_pairs, n_steps - 1) coefficient matrix; row k has ones in
            columns i..j-1 for the pair (i, j) measured in ``shifts[k]``

        Rows whose shift contains NaN are dropped from both arrays. Which pairs
        are used depends on `method` and `corr_window`.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                             (ft_images.shape[3] - 1) * 2]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]), key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # builtin float: the np.float alias no longer exists in current NumPy
            cc_file_args = (self.debug_cor_file, float, tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0], dtype=cc_file_args[1], mode="w+", shape=cc_file_args[2])

            # must be a list: it is indexed by `counter` below, and on Python 3
            # a bare zip() is a one-shot iterator that cannot be subscripted
            cc_args = list(zip(range(shifts.shape[0]), (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            # DCC only compares the first image against all others
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]
            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                # outside DCC, only correlate pairs at most corr_window apart
                if (self.method != "DCC") and (self.corr_window > 0) and (j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift, None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                            time.time() - self._start_time, counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(range(len(autocor_shift_cache)),
                       ft_1_cache,
                       ft_2_cache,
                       autocor_shift_cache,
                       len(ft_1_cache) * ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                       cc_args)
            # results arrive out of order; `j` is the row each result belongs to
            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                        time.time() - self._start_time, i + 1, coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop measurements with NaN in any component
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (np.linalg.matrix_rank(coefs) == n_steps - 1), \
            "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs

    def rcc(self, shift_max, t_shift, shifts, coefs, ):
        """
        Should probably rename function. Takes cross correlation results and calculates shifts.

        Solves ``coefs @ drifts = shifts`` with a pseudo-inverse. For
        ``method == "RCC"``, rows whose residual norm exceeds `shift_max` are
        then zeroed (largest first, and only while the matrix stays full rank)
        and the system is solved again. Note that `coefs` is modified in place.

        Returns
        -------
        t_shift : passed through unchanged
        drifts : step-to-step drifts with a zero row prepended for the first
            time point
        """
        print("{:.2f} s. About to start solving shifts array.".format(time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() - self._start_time))

        if self.method == "RCC":
            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print("Could not remove all residuals over shift_max threshold.")
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # derived versions of RCC need to override this method
        # 'execute' of this RCC base class is not thoroughly tested as its use is probably quite limited.
        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            # leave one core free, but always use at least one worker
            proccess_count = np.clip(multiprocessing.cpu_count() - 1, 1, None)
            self._pool = multiprocessing.Pool(processes=proccess_count)

        # NOTE(review): `cache_fft` is a file name, but
        # calc_corr_drift_from_ft_images() indexes its argument as an array -
        # confirm how the base class is meant to load the cached FFTs.
        drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
        # NOTE(review): rcc() takes (shift_max, t_shift, shifts, coefs) but
        # drift_res only holds (shifts, coefs), so this call looks one argument
        # short - confirm where t_shift should come from.
        t_shift, shifts = self.rcc(self.shift_max, *drift_res)

        if self.multiprocessing:
            self._pool.close()
            self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    `pipeline` may be either an object exposing `get_layer_data()` /
    `dataSources`, or a plain dictionary mapping names to data sources.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        # a dh5view display_options instance - if provided, this over-rides the clim, cmap properties
        self._do = display_opts

        # (dsname, slice, channel) of the currently cached image plane
        self._im_key = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (not self._pipeline is None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource named by `dsname`, looked up through the
        pipeline's `get_layer_data()` or, if the pipeline has no such method,
        by treating the pipeline as a dictionary.
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            # fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]

    @property
    def _ds_class(self):
        """The data source type this layer can display."""
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        """Instantiate the rendering engine for the current `method` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """Refresh the list of selectable data sources and reload the image."""
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items()
                                        if isinstance(v, self._ds_class)]
        except AttributeError:
            # pipeline is a plain dictionary
            self._datasource_choices = [k for k, v in self._pipeline.items()
                                        if isinstance(v, self._ds_class)]

        if not (self.engine is None or self.datasource is None):
            print('lw update')
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """
        Copy colour limits, colour map and visibility for our channel from a
        display options object (`do`, or the one given at construction).
        Does nothing if neither is available.
        """
        if (do is None):
            if not (self._do is None):
                do = self._do
            else:
                return

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Cache the image plane selected by `slice` / `channel` together with
        its bounds and the current display settings.

        Parameters
        ----------
        ds : PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        clim = self.clim
        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # only re-read the pixel data when the selected plane actually changed
        im_key = (self.dsname, self.slice, self.channel)
        if not self._im_key == im_key:
            self._im_key = im_key
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

        x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

        self._bbox = np.array([x0, y0, 0, x1, y1, 0])
        self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        """Return every 20th pixel of the cached image (data for the histogram editor)."""
        im = getattr(self, '_im', None)
        if im is None:
            # no image has been loaded yet; return a placeholder range rather
            # than raising AttributeError
            return np.array([0, 1])
        return im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class LayerWrapper(HasTraits):
    """
    Pairs a rendering engine with a pipeline data source and the display
    settings (colour map, limits, alpha, visibility) used to draw it.

    Parameters
    ----------
    pipeline : object providing `layer_datasources` and an `onRebuild` signal
    method : key into ENGINES selecting the rendering engine
    ds_name : name of the data source; '' selects the pipeline itself and
        'name.channel' selects a colour channel of a ColourFilter
    cmap, clim, alpha, visible : initial display settings (`clim` defaults to [0, 1])
    method_args : optional dict of traits to set on the engine after creation
    """
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self, pipeline, method='points', ds_name='', cmap='gist_rainbow', clim=None, alpha=1.0,
                 visible=True, method_args=None):
        self._pipeline = pipeline
        self.engine = None

        self.cmap = cmap
        # None sentinels replace the former mutable default arguments
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha

        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        self._eng_params = {} if method_args is None else dict(method_args)
        self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        return self.engine.bbox

    @property
    def colour_map(self):
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """Names of all data sources, plus 'name.channel' entries for each ColourFilter channel."""
        names = []
        for k, v in self._namespace.items():
            names.append(k)
            if isinstance(v, tabular.ColourFilter):
                for c in v.getColourChans():
                    names.append('.'.join([k, c]))

        return names

    @property
    def datasource(self):
        """The data source selected by `dsname`, or None if it does not exist."""
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            parent = self._namespace.get(ds, None)
            if parent is None:
                # unknown parent data source: report "no data" instead of
                # failing with AttributeError on None
                return None
            return parent.get_channel_ds(channel)
        else:
            return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """Swap in the engine for the current `method`, carrying over engine settings."""
        if self.engine:
            # remember the settings of the engine we are about to replace
            self._eng_params = self.engine.get('point_size', 'vertexColour')

        # NOTE(review): `_context` is not assigned anywhere in this class -
        # confirm where it is expected to come from.
        self.engine = ENGINES[self.method](self._context)
        self.engine.set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset `clim` to the range of the engine's colour variable."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Push the current data source and display settings to the engine."""
        print('lw update')
        if not (self.engine is None or self.datasource is None):
            self.engine.update_from_datasource(self.datasource, getattr(cm, self.cmap), self.clim, self.alpha)
            self.on_update.send(self)

    def render(self, gl_canvas):
        if self.visible:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """Return the values of the colour variable, or `[0, 1]` if it is not a valid key."""
        try:
            cdata = self.datasource[self.engine.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(values=self.data_source_names)), ]),
                     Item('method'),
                     Group([Item('engine', style='custom', show_label=False,
                                 editor=InstanceEditor(view=self.engine.view(self.datasource.keys()))), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'), Item('alpha'), Item('visible')],
                           orientation='horizontal', layout='flow')],
                    )

    def default_traits_view(self):
        return self.default_view
class CalculateFRCBase(ModuleBase):
    """
    Base class. Refer to derived classes for docstrings.

    Provides the shared machinery for a Fourier ring/shell correlation (FRC)
    estimate from a pair of images: optional Tukey windowing, FFT, radial
    binning, smoothing of the correlation curve and threshold crossing
    detection.

    NOTE(review): the helper methods read ``self._namespace``, ``self._pool``
    and ``self._pixel_size_in_nm``, none of which are set in this class -
    presumably derived classes assign them before calling in; confirm.
    """
    pre_filter = Enum(['Tukey_1/8', None])
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    multiprocessing = Bool(True)
    cubic_smoothing = Float(0.01)
    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, namespace=None):
        """Not implemented here; derived classes provide the real ``execute``."""
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Check that all images share one shape and apply the selected pre-filter.

        Parameters
        ----------
        image_pair : sequence of arrays
            The images to compare.

        Returns
        -------
        The (possibly windowed) images.

        Raises
        ------
        AssertionError
            If the image shapes differ.
        ValueError
            If ``pre_filter`` holds an unsupported value.
        """
        dims_length = np.stack([im.shape for im in image_pair], 0)
        assert np.all(dims_length == dims_length[0]), "Images not the same dimension."

        # apply filtering to zero near edge of images
        if self.pre_filter == 'Tukey_1/8':
            image_pair = self.filter_images_tukey(image_pair, 1./8)
        elif self.pre_filter is None:
            pass
        else:
            raise ValueError("Unsupported pre_filter: %r" % (self.pre_filter,))

        return image_pair

    def filter_images_tukey(self, images, alpha):
        """
        Multiply the first two images by a separable n-dimensional Tukey window.

        Parameters
        ----------
        images : sequence of arrays
            At least two arrays with the shape of ``images[0]``.
        alpha : float
            Tukey window shape parameter, passed to ``tukey``.

        Returns
        -------
        list
            ``[images[0] * window, images[1] * window]`` (inputs are not modified).
        """
        try:
            from scipy.signal.windows import tukey
        except ImportError:
            # fall back to the location used by older SciPy releases
            from scipy.signal import tukey

        shape = images[0].shape
        windows = [tukey(n, alpha=alpha) for n in shape]
        # outer product of the per-axis 1D windows
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)

        return [images[0]*window_nd, images[1]*window_nd]

    def calculate_FRC_from_images(self, image_pair, mdh):
        """
        Compute the FRC curve and resolution estimates for an image pair.

        Stores an image stack of the two power spectra and the cross spectrum
        in ``self._namespace[self.output_fft_images_cc]`` (best effort - a
        failure here is printed and otherwise ignored).

        Parameters
        ----------
        image_pair : sequence of two arrays of identical shape
        mdh
            Currently unused.

        Returns
        -------
        res : dict
            Threshold name -> reciprocal of the crossing frequency.
        rawdata : dict
            The curves underlying ``res`` (see :meth:`calculate_threshold`).
        """
        if self.multiprocessing:
            # dispatch both transforms before waiting on either
            pending = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [job.get() for job in pending]
            del pending
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        ref = image_pair[0]
        im_fft_freqs = [np.fft.fftfreq(ref.shape[i], self._pixel_size_in_nm[i]) for i in range(ref.ndim)]
        # radial spatial frequency of every FFT sample
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        im1_fft_power = np.multiply(ft_images[0], np.conj(ft_images[0]))
        im2_fft_power = np.multiply(ft_images[1], np.conj(ft_images[1]))
        im12_fft_power = np.multiply(ft_images[0], np.conj(ft_images[1]))

        try:
            self._namespace[self.output_fft_images_cc] = ImageStack(
                data=np.stack([np.atleast_3d(np.fft.fftshift(im1_fft_power)),
                               np.atleast_3d(np.fft.fftshift(im2_fft_power)),
                               np.atleast_3d(np.fft.fftshift(im12_fft_power))], 3),
                titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            # the diagnostic image is optional; do not let it abort the FRC calculation
            print (e)

        radii = im_R.flatten()
        im1_fft_flat_res = CalculateFRCBase.BinData(radii, im1_fft_power.flatten(), statistic='mean', bins=201)
        im2_fft_flat_res = CalculateFRCBase.BinData(radii, im2_fft_power.flatten(), statistic='mean', bins=201)
        im12_fft_flat_res = CalculateFRCBase.BinData(radii, im12_fft_power.flatten(), statistic='mean', bins=201)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic*im2_fft_flat_res.statistic))

        freq = im12_fft_flat_res.bin_edges[:-1]
        smoothed_frc = self.smooth_frc(freq, corr, self.cubic_smoothing)

        res, rawdata = self.calculate_threshold(freq, corr, smoothed_frc, im12_fft_flat_res.counts)

        return res, rawdata

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """
        Build a callable ``f(x=...)`` approximating ``corr`` as a function of ``freq``.

        The form depends on ``frc_smoothing_func``: ``None`` gives a 'next'
        interpolator, ``'Sigmoid'`` a least-squares fitted :meth:`Sigmoid`, and
        ``'Cubic Spline'`` a smoothing spline with smoothing factor
        ``cubic_smoothing``.
        """
        if self.frc_smoothing_func is None:
            return interpolate.interp1d(freq, corr, kind='next', )

        elif self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            # integer division: a float is not a valid index; args must be a tuple
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x)-corr)),
                                        [1, freq[len(freq)//2]], args=(freq,), method='Nelder-Mead')
            return partial(func, n=fit_res.x[0], c=fit_res.x[1])

        elif self.frc_smoothing_func == "Cubic Spline":
            return interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Locate where the smoothed correlation crosses the 1/7, 3 sigma and
        half-bit threshold curves.

        Each crossing is found by Nelder-Mead minimisation of the squared
        difference, started at the first sampled frequency where the smoothed
        curve is below the threshold. A plot of the result is drawn and stored
        in ``self._namespace[self.output_frc_plot]``.

        Parameters
        ----------
        freq, corr : arrays
            Sampled frequencies and raw correlation values.
        corr_func : callable
            Smoothed correlation, called as ``corr_func(x=...)``.
        counts : array
            Number of samples contributing to each frequency bin.

        Returns
        -------
        res : dict
            ``'frc 1/7'``, ``'frc 3 sigma'``, ``'frc half bit'`` -> 1 / crossing frequency.
        rawdata : dict
            Sampled raw, smoothed and threshold curves.
        """
        res = dict()

        fsc_0143 = optimize.minimize(lambda x: np.square(corr_func(x=x)-0.143),
                                     freq[np.argmax(corr_func(x=freq)-0.143 < 0)], method='Nelder-Mead')
        res['frc 1/7'] = 1./fsc_0143.x[0]

        sigma = 1.0 / np.sqrt(counts*0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = optimize.minimize(lambda x: np.square(corr_func(x=x)-3.*sigma_spl(x)),
                                       freq[np.argmax(corr_func(x=freq)-3.*sigma_spl(freq) < 0)],
                                       method='Nelder-Mead')
        res['frc 3 sigma'] = 1./fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        half_bit = (0.2071 + 1.9102 / np.sqrt(counts)) / (1.2071 + 0.9102 / np.sqrt(counts))
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = optimize.minimize(lambda x: np.square(corr_func(x=x)-half_bit_spl(x)),
                                         freq[np.argmax(corr_func(x=freq)-half_bit_spl(freq) < 0)],
                                         method='Nelder-Mead')
        res['frc half bit'] = 1./fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1,2,figsize=(10,4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7: {:.2f} nm".format(1./fsc_0143.x[0])

            axes[0].plot(freq, 3*sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma: {:.2f} nm".format(1./fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')
            frc_text += "\n1/2 bit: {:.2f} nm".format(1./fsc_half_bit.x[0])

            axes[0].legend()

            # relabel the frequency axis with the corresponding resolution
            x_ticklocs = axes[0].get_xticks()
            axes[0].set_xticklabels(["{:.1f}".format(1./i) for i in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center', verticalalignment='center',
                         transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()

        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq':freq, 'corr':corr, 'smooth':corr_func(x=freq), '1/7':np.ones_like(freq)/7,
                   '3 sigma':3*sigma_spl(freq), '1/2 bit':half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """
        Write the raw FRC curves and results from ``namespace`` to
        ``save_path`` as a compressed ``.npz``; does nothing when
        ``save_path`` is empty. I/O errors propagate to the caller.
        """
        # compare by value: identity comparison against a literal is unreliable
        if self.save_path != "":
            np.savez_compressed(self.save_path, raw=namespace[self.output_frc_raw],
                                results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Calculates binned statistics. Supports complex number.

        Parameters
        ----------
        indexes : array
            Values deciding which bin each sample falls in.
        data : array
            Samples to aggregate; same number of elements as ``indexes``.
        statistic : {'mean', 'sum'}
            Aggregation applied within each bin.
        bins : int
            Number of bin *edges*, evenly spaced between ``indexes.min()``
            and ``indexes.max()`` (giving ``bins - 1`` bins).

        Returns
        -------
        Object with attributes ``statistic`` (per-bin value, dtype of
        ``data``), ``bin_edges`` and ``counts``.

        Raises
        ------
        ValueError
            If ``statistic`` is not supported.
        """
        if statistic == 'mean':
            func = np.mean
        elif statistic == 'sum':
            func = np.sum
        else:
            raise ValueError("Unsupported statistic: %r" % (statistic,))

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        bin_edges = np.linspace(indexes.min(), indexes.max(), bins)
        n_bins = len(bin_edges) - 1
        binned = np.zeros(n_bins, dtype=data.dtype)
        # builtin int: the np.int alias no longer exists in current NumPy
        counts = np.zeros(n_bins, dtype=int)

        flat_indexes = indexes.flatten()
        order = np.argsort(flat_indexes)
        indexes_sorted = flat_indexes[order]
        data_sorted = data.flatten()[order]

        # position of the first sorted sample >= each edge, so every bin is
        # half-open [lo, hi); samples equal to the last edge are not counted
        edges_indexes = np.searchsorted(indexes_sorted, bin_edges)

        for i in range(n_bins):
            values = data_sorted[edges_indexes[i]:edges_indexes[i+1]]
            binned[i] = func(values)
            counts[i] = len(values)

        res = Result()
        res.statistic = binned
        res.bin_edges = bin_edges
        res.counts = counts
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """Decreasing logistic curve ``1 - 1 / (1 + exp(n * (c - x)))``."""
        res = 1 - 1 / (1 + np.exp(n*(-x+c)))
        return res