class ExtractChannel(ModuleBase):
    """Extract one channel from an image"""
    inputName = Input('input')
    outputName = Output('filtered_image')

    channelToExtract = Int(0)

    def _pickChannel(self, image):
        """Return a new ImageStack containing only channel ``channelToExtract``.

        Metadata entries are copied from the source image. ``ChannelNames`` is
        set to the extracted channel's name when the source provides one, and
        ``Parent`` records the source image's filename.
        """
        chan = image.data[:, :, :, self.channelToExtract]

        im = ImageStack(chan, titleStub='Filtered Image')
        im.mdh.copyEntriesFrom(image.mdh)
        try:
            im.mdh['ChannelNames'] = [image.names[self.channelToExtract], ]
        except (KeyError, AttributeError, IndexError):
            # Channel names are optional (and may be shorter than the channel
            # count); a missing name should not abort the extraction.
            logger.warning("Error setting channel name")

        im.mdh['Parent'] = image.filename

        return im

    def execute(self, namespace):
        namespace[self.outputName] = self._pickChannel(namespace[self.inputName])
class HistByID(ModuleBase):
    """Plot histogram of a column by ID"""
    inputName = Input('measurements')
    IDkey = CStr('objectID')
    histkey = CStr('qIndex')
    outputName = Output('outGraph')
    nbins = Int(50)
    minval = Float(float('nan'))
    maxval = Float(float('nan'))

    def execute(self, namespace):
        """Draw a histogram of ``histkey`` using one value per unique ``IDkey``.

        ``minval`` / ``maxval`` give the histogram range; NaN (the default)
        means "use the minimum / maximum of the data".

        NOTE(review): the histogram is drawn into a new matplotlib figure but
        nothing is stored in ``namespace[self.outputName]``.
        """
        import math

        meas = namespace[self.inputName]
        ids = meas[self.IDkey]

        uid, valsu = uniqueByID(ids, meas[self.histkey])

        minv = valsu.min() if math.isnan(self.minval) else self.minval
        maxv = valsu.max() if math.isnan(self.maxval) else self.maxval

        import matplotlib.pyplot as plt
        plt.figure()
        plt.hist(valsu, self.nbins, range=(minv, maxv))
        plt.xlabel(self.histkey)

    @property
    def _key_choices(self):
        # try and find the available column names; fall back to an empty list
        # when they cannot be looked up (catching Exception rather than using a
        # bare except so KeyboardInterrupt/SystemExit still propagate).
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('histkey', editor=CBEditor(choices=self._key_choices)),
                    Item('nbins'),
                    Item('minval'),
                    Item('maxval'),
                    Item('_'),
                    Item('outputName'), buttons=['OK'])
class SlidingWindowMAD(ModuleBase):
    """
    Using a rolling window along the time (/ z) dimension, calculate the
    median-absolute deviation (MAD)

    Parameters
    ----------
    series: PYME.IO.image.ImageStack
    time_window_size: int
        Size of window to use in the rolling MAD calculation

    Returns
    -------
    output: PYME.IO.image.ImageStack
        MAD calculated within the rolling window. Note that the window size is
        kept constant, so output will be a shorter series than the input: an
        input of ``T`` frames gives ``T - time_window_size + 1`` output frames
        (one per complete window, none if ``T < time_window_size``).

    Notes
    -----
    Currently only set up for single-color data
    """
    input = Input('input')
    time_window_size = Int(10)
    process_frames_individually = False

    output = Output('MAD')

    def execute(self, namespace):
        try:
            from scipy.stats import median_abs_deviation as mad
        except ImportError:
            # older scipy only has the (since removed) previous name
            from scipy.stats import median_absolute_deviation as mad

        series = namespace[self.input]
        window = self.time_window_size

        # windows start at 0 .. T - window inclusive -> T - window + 1 of them
        n_windows = max(series.data.shape[2] - window + 1, 0)
        output = np.empty((series.data.shape[0], series.data.shape[1], n_windows),
                          dtype=series.data[:, :, 0, 0].dtype)  # only 1 color for now

        for ti in range(n_windows):
            output[:, :, ti] = mad(series.data[:, :, ti:ti + window], scale=1, axis=2)

        out = image.ImageStack(data=output)
        out.mdh = MetaDataHandler.NestedClassMDHandler()
        try:
            out.mdh.copyEntriesFrom(series.mdh)
        except AttributeError:
            pass
        out.mdh['Analysis.FilterSpikes.TimeWindowSize'] = self.time_window_size

        namespace[self.output] = out
class InterpolateDrift(ModuleBase):
    """
    Fit a smoothing spline (``scipy.interpolate.UnivariateSpline``) to raw
    drift measurements and expose it as a drift interpolator.

    Inputs
    ------
    input_drift_raw : tuple of arrays
        ``(tIndex, drift)`` measured from a localization or image dataset.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator; returns drift when called with frame number / time.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    degree_of_spline : int
        Degree of the smoothing spline (1 for linear, 3 for cubic).
    smoothing_factor : float
        Smoothing factor (0 for no smoothing; negative selects the
        UnivariateSpline default).
    """
    degree_of_spline = Int(3)
    smoothing_factor = Float(-1)

    input_drift_raw = Input('drift_raw')
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        t_index, raw_drift = namespace[self.input_drift_raw]

        fitted = interpolate_drift(t_index, raw_drift,
                                   self.degree_of_spline,
                                   self.smoothing_factor)
        namespace[self.output_drift_interpolator] = fitted

        # The plot is not needed downstream; it only visualises the fit.
        plot_fn = partial(generate_drift_plot, t_index, raw_drift, fitted)
        namespace[self.output_drift_plot] = Plot(plot_fn)
        namespace[self.output_drift_plot].plot()
class CNNFilter(Filter):
    """
    Use a previously trained Keras neural network to filter the data. Used for learnt de-noising and/or deconvolution.
    Runs prediction piecewise over the image with over-lapping ROIs and averages the prediction results.

    Notes
    -----

    Keras and either Tensorflow or Theano must be installed set up for this module to work. This is not a default
    dependency of python-microscopy as the conda-installable versions don't have GPU support.
    """
    model = FileOrURI('')
    step_size = Int(14)

    def _load_model(self):
        """(Re)load the Keras model if ``model`` has changed since the last load."""
        from keras.models import load_model
        if getattr(self, '_model_name', None) != self.model:
            with unifiedIO.local_or_temp_filename(self.model) as fn:
                self._model = load_model(fn)
            # record the name only after a successful load, so that a failed
            # load is retried on the next call instead of being skipped
            self._model_name = self.model

    def applyFilter(self, data, chanNum, frNum, im):
        """Predict on overlapping ``step_size``-strided patches and average them.

        Patch size is taken from the model's ``input_shape``
        (``(_, kernel_size_x, kernel_size_y, _)``). Returns a float32 array
        with the shape of ``data``.
        """
        self._load_model()

        out = np.zeros(data.shape, 'f')

        _, kernel_size_x, kernel_size_y, _ = self._model.input_shape

        # An interior pixel is covered by about (kx / step) * (ky / step)
        # patches, so scale each prediction by step**2 / (kx * ky) to average.
        scale_factor = float(self.step_size) ** 2 / (kernel_size_x * kernel_size_y)

        # "+ 1" so the last patch that still fits inside the image is included
        for i_x in range(0, out.shape[0] - kernel_size_x + 1, self.step_size):
            for i_y in range(0, out.shape[1] - kernel_size_y + 1, self.step_size):
                d_i = data[i_x:(i_x + kernel_size_x),
                           i_y:(i_y + kernel_size_y)].reshape(1, kernel_size_x, kernel_size_y, 1)
                p = self._model.predict(d_i).reshape(kernel_size_x, kernel_size_y)
                out[i_x:(i_x + kernel_size_x), i_y:(i_y + kernel_size_y)] += scale_factor * p

        return out

    def completeMetadata(self, im):
        im.mdh['CNNFilter.model'] = self.model
class PointFeatureBase(ModuleBase):
    """
    common base class for feature extraction routines - implements normalisation and PCA routines
    """

    outputColumnName = CStr('features')
    columnForEachFeature = Bool(False)  # if true, outputs a column for each feature - useful for visualising

    normalise = Bool(True)  # subtract mean and divide by std. deviation

    PCA = Bool(True)  # reduce feature dimensionality by performing PCA - TODO - should this be a separate module and be chained instead?
    PCA_components = Int(3)  # 0 = same dimensionality as features

    def _process_features(self, data, features):
        """Normalise / PCA-reduce ``features`` and attach them to ``data``.

        Parameters
        ----------
        data : tabular data source
            Wrapped in a ``tabular.MappingFilter``; its ``mdh`` (if any) is propagated.
        features : 2D array
            One row per point, one column per feature.

        Returns
        -------
        tabular.MappingFilter with the processed features in
        ``outputColumnName`` (plus ``feat_<i>`` columns if ``columnForEachFeature``).
        """
        from PYME.IO import tabular
        out = tabular.MappingFilter(data)
        out.mdh = getattr(data, 'mdh', None)

        if self.normalise:
            features = features - features.mean(0)[None, :]
            std = features.std(0)
            # constant features have zero spread; dividing by 1 leaves them at 0
            # instead of producing NaN/inf (which would also break the PCA below)
            std[std == 0] = 1
            features = features / std[None, :]

        if self.PCA:
            from sklearn.decomposition import PCA

            n_components = self.PCA_components if self.PCA_components > 0 else None
            pca = PCA(n_components=n_components).fit(features)
            features = pca.transform(features)

            out.pca = pca  # save the pca object just in case we want to look at what the principle components are (this is hacky)

        out.addColumn(self.outputColumnName, features)

        if self.columnForEachFeature:
            for i in range(features.shape[1]):
                out.addColumn('feat_%d' % i, features[:, i])

        return out
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its distances to all other points
    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  # number of bins (starting at 0)
    threeD = Bool(True)

    normaliseRelativeDensity = Bool(False)  # divide by the sum of all radial bins. If not performed, the first principle component will likely be average density

    def execute(self, namespace):
        from PYME.Analysis.points import DistHist
        points = namespace[self.inputLocalisations]

        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([DistHist.distanceHistogram3D(xi, yi, zi, x, y, z, self.numBins, self.binWidth)
                          for xi, yi, zi in zip(x, y, z)])
        else:
            x, y = points['x'], points['y']
            f = np.array([DistHist.distanceHistogram(xi, yi, x, y, self.numBins, self.binWidth)
                          for xi, yi in zip(x, y)])

        if self.normaliseRelativeDensity:
            # divide each point's histogram by its total count; rows with no
            # counts are left as-is rather than becoming NaN
            totals = f.sum(1).astype('f8')
            totals[totals == 0] = 1
            f = f / totals[:, None]

        namespace[self.outputName] = self._process_features(points, f)
class DBSCANClustering2(ModuleBase):
    """
    Performs DBSCAN clustering on input dictionary

    Parameters
    ----------

    columns: names of the input columns used as point coordinates
    searchRadius: search radius for clustering (``eps``)
    minClumpSize: number of points within searchRadius required for a given point to be considered a core point
        (``min_samples``)
    numberOfJobs: number of parallel jobs passed to dbscan as ``n_jobs``

    Notes
    -----

    See `sklearn.cluster.dbscan` for more details about the underlying algorithm and parameter meanings.
    The output adds a ``dbscanClumpID`` column where 0 means unclustered.
    """
    import multiprocessing
    inputName = Input('filtered')

    columns = ListStr(['x', 'y', 'z'])
    searchRadius = Float()
    minClumpSize = Int()

    # n_jobs is only accepted by newer sklearn versions (see fallback in execute)
    numberOfJobs = Int(max(multiprocessing.cpu_count() - 1, 1))

    outputName = Output('dbscanClustered')

    def execute(self, namespace):
        from sklearn.cluster import dbscan

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        points = np.vstack([inp[k] for k in self.columns]).T

        # eps / min_samples are passed by keyword: recent sklearn makes
        # min_samples keyword-only, so positional calls raise TypeError.
        # Note that sklearn gives unclustered points label of -1, and first value starts at 0.
        try:
            core_samp, dbLabels = dbscan(points, eps=self.searchRadius, min_samples=self.minClumpSize,
                                         n_jobs=self.numberOfJobs)
            multiproc = True
        except TypeError:
            # older sklearn without the n_jobs keyword
            core_samp, dbLabels = dbscan(points, eps=self.searchRadius, min_samples=self.minClumpSize)
            multiproc = False

        if multiproc:
            logger.info('using dbscan multiproc version')
        else:
            logger.info('falling back to dbscan single-threaded version')

        # shift dbscan labels up by one to match existing convention that a clumpID of 0 corresponds to unclumped
        mapped.addColumn('dbscanClumpID', dbLabels + 1)

        # propogate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        return ['columns']
class LoadDriftandInterp(ModuleBase):
    """
    Loads drift data from file(s) and use them to create a spline interpolator (``scipy.interpolate.UnivariateSpline``).

    Inputs
    ------
    input_dummy : None
        Blank input. Required to run correctly.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator. Returns drift when called with frame number / time.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    load_paths : list of File
        List of files to load. Each must provide ``tIndex`` and ``drift`` entries.
    degree_of_spline : int
        Degree of the smoothing spline.
    smoothing_factor : float
        Smoothing factor.
    """
    input_dummy = Input('input')  # breaks GUI without this???
    load_paths = List(File, [""], 1)
    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(-1)  # 0 for no smoothing. set to negative for UnivariateSpline default

    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    @staticmethod
    def _load_drift(path):
        """Return ``(tIndex, drift)`` read from ``path``, closing the file afterwards."""
        data = np.load(path)
        try:
            return data['tIndex'], data['drift']
        finally:
            # .npz files come back as an open NpzFile; plain arrays have no close()
            close = getattr(data, 'close', None)
            if close is not None:
                close()

    def execute(self, namespace):
        spl_array = list()
        tIndexes = list()
        drifts = list()

        for fil in self.load_paths:
            tIndex, drift = self._load_drift(fil)
            tIndexes.append(tIndex)
            drifts.append(drift)

            spl_array.append(interpolate_drift(tIndex, drift, self.degree_of_spline, self.smoothing_factor))

        def spl_method(funcs, t):
            return np.sum([f(t) for f in funcs], axis=0)

        # zip(*spl_array) groups the i-th interpolator of every file together
        # (presumably one per drift axis); each group is summed across files.
        # partial (not a lambda) avoids late binding of the loop variable.
        spl_combined = [partial(spl_method, spl) for spl in zip(*spl_array)]

        namespace[self.output_drift_interpolator] = spl_combined

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, tIndexes, drifts, spl_combined))
        namespace[self.output_drift_plot].plot()
class ClusteringByLabel(ModuleBase):
    """
    Per-label intensity statistics of an image series.

    For every positive label in ``mask`` this averages, over that label's
    pixels, (a) the per-pixel temporal variance / mean computed from frame
    ``excitation_start_frame`` onwards and (b) the per-pixel mean intensity
    over the frames before ``excitation_start_frame``.

    Parameters
    ----------
    input_name : Input
        PYME.IO.ImageStack
    mask : Input
        PYME.IO.ImageStack label image. Optional mask to only calculate metrics
        within labelled regions; labels <= 0 are ignored. If blank, the whole
        image is treated as a single label (1).
    excitation_start_frame : int
        First frame used for the variance/mean calculation; earlier frames are
        used for the pre-excitation mean.
    output_vom : CStr
        If non-blank, also publish the per-pixel variance/mean image under this name.
    output_mean_pre_excitation : CStr
        If non-blank, also publish the per-pixel pre-excitation mean image under this name.

    Returns
    -------
    output_name : Output
        tabular.DictSource with columns 'variance_over_mean',
        'mean_intensity_over_first_10_frames' and 'labels'.

    Notes
    -----
    The 'mean_intensity_over_first_10_frames' column name is fixed, but it holds
    the mean over the first ``excitation_start_frame`` frames.
    """
    input_name = Input('input')

    mask = Input('')
    excitation_start_frame = Int(10)

    output_vom = CStr('')
    output_mean_pre_excitation = CStr('')
    output_name = Output('cluster_metrics')

    def execute(self, namespace):
        series = namespace[self.input_name]

        # squeeze down from 4D
        data = series.data[:, :, :].squeeze()
        if self.mask == '':  # not the most memory efficient, but make a mask
            logger.debug('No mask provided to ClusteringByLabel, analyzing full image')
            mask = np.ones((data.shape[0], data.shape[1]), int)
        else:
            mask = namespace[self.mask].data[:, :, :].squeeze()
        # toss any negative labels, as well as the zero label (per PYME clustering schema).
        labels = sorted(set(np.clip(np.unique(mask), 0, None)) - {0})
        logger.debug('ClusteringByLabel labels: %s', labels)

        n_labels = len(labels)

        # calculate the Variance_t over Mean_t
        var = np.var(data[:, :, self.excitation_start_frame:], axis=2)
        mean = np.mean(data[:, :, self.excitation_start_frame:], axis=2)

        variance_over_mean = var / mean
        if np.isnan(variance_over_mean).any():
            logger.error('Variance over mean contains NaN, see %s' % series.filename)

        mean_pre_excitation = np.mean(data[:, :, :self.excitation_start_frame], axis=2)

        cluster_metric_mean = np.zeros(n_labels)
        mean_before_excitation = np.zeros(n_labels)

        for li, label in enumerate(labels):
            # everything is 2D at this point
            label_mask = mask == label
            cluster_metric_mean[li] = np.mean(variance_over_mean[label_mask])
            mean_before_excitation[li] = np.mean(mean_pre_excitation[label_mask])

        res = tabular.DictSource({'variance_over_mean': cluster_metric_mean,
                                  'mean_intensity_over_first_10_frames': mean_before_excitation,
                                  'labels': np.array(labels)})
        try:
            res.mdh = series.mdh
        except AttributeError:
            res.mdh = None

        namespace[self.output_name] = res

        if self.output_vom != '':
            namespace[self.output_vom] = image.ImageStack(data=variance_over_mean, mdh=res.mdh)

        if self.output_mean_pre_excitation != '':
            namespace[self.output_mean_pre_excitation] = image.ImageStack(data=mean_pre_excitation, mdh=res.mdh)
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.

    1. Applies median filter for denoising.

    2. Replaces out of range values to defined values.

    3. Applies 2D Tukey filter to dampen potential edge artifacts.

    Inputs
    ------
    input_name : ImageStack

    Outputs
    -------
    output_name : ImageStack

    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``).
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels below the lower threshold are replaced by this value.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.windows.tukey``).
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype
        n_frames = ims.data.shape[2]

        # Somewhat arbitrary way to decide on chunk size
        chunk_size = 100000000 / ims.data.shape[0] / ims.data.shape[1] / dtype.itemsize
        chunk_size = int(max(1, chunk_size))

        # 2D Tukey window over (x, y), broadcast along the frame axis.
        # signal.tukey was removed from the scipy.signal namespace; use signal.windows.
        tukey_mask_x = signal.windows.tukey(ims.data.shape[0], self.tukey_size)
        tukey_mask_y = signal.windows.tukey(ims.data.shape[1], self.tukey_size)
        self._tukey_mask_2d = np.multiply(*np.meshgrid(tukey_mask_x, tukey_mask_y, indexing='ij'))[:, :, None]

        # np.long no longer exists (NumPy >= 1.24); plain ints suffice for a shape
        out_shape = tuple(int(s) for s in ims.data.shape[:3])
        if self.cache_clip == "":
            raw_data = np.empty(out_shape, dtype=dtype)
        else:
            raw_data = np.memmap(self.cache_clip, dtype=dtype, mode='w+', shape=out_shape)

        progress = 0.2 * n_frames
        for f in np.arange(0, n_frames, chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(ims.data[:, :, f:f + chunk_size])

            # report (and flush the disk cache) roughly every 20% of the frames
            if f + chunk_size >= progress:
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += 0.2 * n_frames
                print("{:.2f} s. Completed clipping {} of {} total images.".format(
                    time.time() - self._start_time, min(f + chunk_size, n_frames), n_frames))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
        Performs the actual filtering here: median filter, clip, offset by
        ``clip_to_lower`` and apply the Tukey mask. Never modifies ``data`` in place.
        """
        if self.median_filter_size > 0:
            data = ndimage.median_filter(data, self.median_filter_size, mode='nearest')
        else:
            # copy so the in-place clipping below cannot write into the source
            data = np.array(data)

        data[data >= self.threshold_upper] = self.clip_to_upper
        data[data <= self.threshold_lower] = self.clip_to_lower
        # out-of-place: an in-place float subtraction raises a casting error on integer data
        data = data - self.clip_to_lower

        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d

        return data

    def completeMetadata(self, im):
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-corelation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """Cross-correlate pairs of Fourier-transformed frames.

        Returns ``(shifts, coefs)``: ``shifts`` has one row (3 columns) per
        pair (i, j) and ``coefs`` has ones in columns ``i:j`` of that row, so
        that ``coefs @ drift_steps ~= shifts``. Rows whose shift is NaN are
        dropped.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [shifts.shape[0], ft_images.shape[1], ft_images.shape[2], (ft_images.shape[3] - 1) * 2]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]), key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float no longer exists (NumPy >= 1.24); it was an alias for float64
            cc_file_args = (self.debug_cor_file, np.float64, tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0], dtype=cc_file_args[1], mode="w+", shape=cc_file_args[2])
            # materialise as a list: the serial path below indexes cc_args[counter],
            # which a Python 3 zip iterator does not support
            cc_args = list(zip(range(shifts.shape[0]), (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift, None, cc_args[counter])

                    if (counter + 1) % max(coefs_size // 5, 1) == 0:
                        print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                            time.time() - self._start_time, counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache, autocor_shift_cache,
                       len(ft_1_cache) * ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                       cc_args)
            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if (i + 1) % max(coefs_size // 5, 1) == 0:
                    print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                        time.time() - self._start_time, i + 1, coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(coefs, axis=1))), "Coefficient matrix filled less than expected."

        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (np.linalg.matrix_rank(coefs) == n_steps - 1), \
            "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs  # shifts.shape[0] is n_steps - 1

    def rcc(self, shift_max, t_shift, shifts, coefs, ):
        """
        Should probably rename function.
        Takes cross correlation results and calculates shifts.

        Solves ``coefs @ drifts = shifts`` via the pseudo-inverse. For the
        'RCC' method, rows whose residual norm exceeds ``shift_max`` are zeroed
        (largest residual first, stopping once removal would lose full column
        rank) and the system is solved again. Note: ``coefs`` is modified in
        place in that case. Returns ``(t_shift, drifts)`` with a zero row
        prepended for the first time point.
        """
        print("{:.2f} s. About to start solving shifts array.".format(time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() - self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print("Could not remove all residuals over shift_max threshold.")
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # derived versions of RCC need to override this method
        # 'execute' of this RCC base class is not thoroughly tested as its use is probably quite limited.
        # NOTE(review): as written, calc_corr_drift_from_ft_images receives the
        # cache filename (a str, not an array) and returns (shifts, coefs),
        # whereas rcc expects (t_shift, shifts, coefs) — confirm before using
        # this base implementation directly.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # always release worker processes, even if the calculation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.
    """
    # Additional properties (the rest are inherited from TriangleRenderLayer)
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, context, **kwargs)

        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Feeds the triangle points/normals to update_data.

        Node selection: only nodes whose parent holds at least ``min_points``
        points are considered. depth > 0 renders nodes at that depth with
        density >= ``density``; depth == 0 renders every non-empty node;
        depth < 0 renders leaf nodes (no children, depth > 0).

        Parameters
        ----------
        ds : Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            nodes = nodes[node_density >= self.density]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2 ** nodes['depth']) ** 3)
        else:
            # Plot leaf nodes
            nodes = nodes[(np.sum(nodes['children'], axis=1) == 0) & (nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2.0 ** nodes['depth'])) ** 3

        if len(nodes) == 0:
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))
            return

        c = nodes['centre']  # center
        shifts = (box_sizes[:, None] * OCT_SHIFT[None, :]) * 0.5
        v = (c[:, None, :] + shifts)

        #
        #     z
        #     ^
        #     |
        #    v4 ----------v6
        #    /|           /|
        #   / |          / |
        #  v5----------v7  |
        #  |  |    c    |  |
        #  | v0---------|-v2
        #  | /          | /
        #  v1-----------v3---> y
        #  /
        # x
        #
        # Counterclockwise triangles (when viewed straight-on) formed along
        # the faces of an octree box:
        #
        # v0 v2 v1 / v0 v1 v5 / v0 v5 v4 / v0 v6 v2 / v0 v4 v6 / v1 v2 v3
        # v1 v3 v7 / v1 v7 v5 / v2 v6 v7 / v2 v7 v3 / v4 v5 v6 / v5 v7 v6
        #
        # np.vstack over the node axis gives node-major rows: the 12 triangles
        # of node 0, then the 12 triangles of node 1, ...
        t0 = np.vstack(v[:, [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 4, 5], :])
        t1 = np.vstack(v[:, [2, 1, 5, 6, 4, 2, 3, 7, 6, 7, 5, 7], :])
        t2 = np.vstack(v[:, [1, 5, 4, 2, 6, 3, 7, 5, 7, 3, 6, 6], :])

        x, y, z = np.hstack([t0, t1, t2]).reshape(-1, 3).T  # positions

        # Now we create the normal as the cross product
        tn = np.cross((t2 - t1), (t0 - t1))

        # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
        xn, yn, zn = np.repeat(tn.T, 3, axis=1)  # normals

        # Color is fixed constant for octree
        c = np.ones(len(x))
        clim = [0, 1]

        alpha_max = alpha.max()
        if alpha_max > 0:
            alpha = self.alpha * alpha / alpha_max
        else:
            # every node is empty: avoid 0/0 -> NaN, render fully transparent
            alpha = np.zeros(len(alpha))
        # one alpha per vertex, in the same node-major order as the vertices
        # (12 triangles x 3 vertices per node)
        alpha = np.repeat(alpha, 12 * 3)
        print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

        cmap = getattr(cm, self.cmap)

        # Concatenate coordinates into vertices.
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)
        self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])

        cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
        cs = cmap(cs_)

        if self.method in ['flat', 'tessel']:
            alpha = cs_ * alpha

        cs[:, 3] = alpha

        if self.method == 'tessel':
            cs = np.power(cs, 0.333)

        self._colors = cs.ravel().reshape(len(c), 4)

        print('Colors: {}'.format(self._colors))

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, EnumEditor

        return View([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
                     Item('method'),
                     Item('depth'), Item('density'), Item('min_points'),
                     Group([Item('cmap', label='LUT'), Item('alpha')])], )
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer, and cmap / clim / alpha /
    # method from TriangleRenderLayer.
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, **kwargs)

        # re-render whenever one of the octree-specific display parameters changes
        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Feeds the triangle points/normals to update_data.

        Which nodes are rendered depends on ``self.depth``:

        * ``depth > 0``: nodes at exactly that depth whose density (nPoints / box volume) is >= ``self.density``
        * ``depth == 0``: every node containing at least one point
        * ``depth < 0``: leaf nodes (nodes without children), excluding nodes at depth 0

        In all cases only nodes whose parent holds at least ``self.min_points`` points are considered.

        Parameters
        ----------
        ds : Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        # keep only nodes whose parent is populated enough
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            nodes = nodes[node_density >= self.density]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T
            alpha = nodes['nPoints'] * ((2 ** nodes['depth']) ** 3)
        else:
            # Plot leaf nodes. The children test must be per node (axis=1); summing over the whole array would yield
            # a single scalar boolean. Depth-0 entries are excluded, matching the other octree layer in this module.
            # NOTE(review): presumably this also drops zero-filled, unused node slots - confirm against Octree.
            is_leaf = np.sum(nodes['children'], axis=1) == 0
            nodes = nodes[is_leaf & (nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T
            alpha = nodes['nPoints'] * ((2.0 ** nodes['depth'])) ** 3

        if len(nodes) > 0:
            # First we need the vertices of the cube. We find them from the center c provided and the box size
            # (lx, ly, lz) provided by the octree:
            c = nodes['centre']  # center
            v0 = c + box_sizes * -1 / 2  # v0 = c - lx/2 - ly/2 - lz/2
            v1 = c + box_sizes * [-1, -1, 1] / 2  # v1 = c - lx/2 - ly/2 + lz/2
            v2 = c + box_sizes * [-1, 1, -1] / 2  # v2 = c - lx/2 + ly/2 - lz/2
            v3 = c + box_sizes * [-1, 1, 1] / 2  # v3 = c - lx/2 + ly/2 + lz/2
            v4 = c + box_sizes * [1, -1, -1] / 2  # v4 = c + lx/2 - ly/2 - lz/2
            v5 = c + box_sizes * [1, -1, 1] / 2  # v5 = c + lx/2 - ly/2 + lz/2
            v6 = c + box_sizes * [1, 1, -1] / 2  # v6 = c + lx/2 + ly/2 - lz/2
            v7 = c + box_sizes / 2  # v7 = c + lx/2 + ly/2 + lz/2

            #
            #           z
            #           ^
            #           |
            #          v1 ----------v3
            #          /|           /|
            #         / |          / |
            #       v5----------v7   |
            #        |  |    c    |  |
            #        | v0---------|-v2
            #        | /          | /
            #       v4-----------v6---> y
            #       /
            #      x
            #
            # The counterclockwise triangles (when viewed straight-on) formed along the faces of the cube are:
            #
            # v0 v2 v6
            # v0 v4 v5
            # v1 v3 v2
            # v2 v0 v1
            # v3 v1 v5
            # v3 v7 v6
            # v4 v6 v7
            # v5 v1 v0
            # v5 v7 v3
            # v6 v2 v3
            # v6 v4 v0
            # v7 v5 v4

            # Concatenate vertices, interleave, restore to 3x(3N) points (3xN triangles),
            # and assign the points to x, y, z vectors
            triangle_v0 = np.vstack((v0, v0, v1, v2, v3, v3, v4, v5, v5, v6, v6, v7))
            triangle_v1 = np.vstack((v2, v4, v3, v0, v1, v7, v6, v1, v7, v2, v4, v5))
            triangle_v2 = np.vstack((v6, v5, v2, v1, v5, v6, v7, v0, v3, v3, v0, v4))

            x, y, z = np.hstack((triangle_v0, triangle_v1, triangle_v2)).reshape(-1, 3).T

            # Normal as the cross product (triangle_v2 - triangle_v1) x (triangle_v0 - triangle_v1)
            triangle_normals = np.cross((triangle_v2 - triangle_v1), (triangle_v0 - triangle_v1))
            # Copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
            xn, yn, zn = np.repeat(triangle_normals.T, 3, axis=1)

            # scale so that the most opaque node gets self.alpha, then expand per face (12) and per vertex (3)
            alpha = self.alpha * alpha / alpha.max()
            alpha = (alpha[None, :] * np.ones(12)[:, None])
            alpha = np.repeat(alpha.ravel(), 3)
            print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

            # Pass the restructured data to update_data
            self.update_data(x, y, z, cmap=getattr(cm, self.cmap), clim=self.clim, alpha=alpha, xn=xn, yn=yn, zn=zn)
        else:
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, EnumEditor

        return View([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
                     Item('method'),
                     Item('depth'),
                     Item('min_points'),
                     Group([Item('cmap', label='LUT'), Item('alpha')])
                     ], )
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    Displays one (slice, channel) plane of an ImageStack datasource taken from the pipeline.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)

    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        # a dh5view display_options instance; if provided, sync_to_display_opts() copies clim / cmap / visibility
        # from it
        self._do = display_opts
        # (dsname, slice, channel) of the cached plane in self._im, plus the datasource object it came from
        self._im_key = None
        self._im_ds = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits. slice / channel select the displayed plane, so they must
        # trigger an update too.
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, slice, channel')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (self._pipeline is not None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            # fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]

    @property
    def _ds_class(self):
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """Refresh the list of candidate datasources and, once an engine exists, re-read the image data."""
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items()
                                        if isinstance(v, self._ds_class)]
        except AttributeError:
            self._datasource_choices = [k for k, v in self._pipeline.items() if isinstance(v, self._ds_class)]

        if self.engine is None:
            return

        # evaluate the datasource property once - it goes through the pipeline
        ds = self.datasource
        if ds is not None:
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """Copy clim, cmap and visibility for self.channel from a display-options object (default: self._do)."""
        if do is None:
            if self._do is None:
                return
            do = self._do

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Extract the displayed plane from an image and store the rendering parameters.

        The plane ``ds.data[:, :, self.slice, self.channel]`` is cached (as float32) and only re-extracted when
        dsname / slice / channel change or when the datasource object itself is replaced (e.g. after a pipeline
        rebuild).

        Parameters
        ----------
        ds : PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        clim = self.clim
        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        im_key = (self.dsname, self.slice, self.channel)

        if (im_key != self._im_key) or (ds is not self._im_ds):
            self._im_key = im_key
            self._im_ds = ds
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds
            self._bbox = np.array([x0, y0, 0, x1, y1, 0])
            self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        # subsample the cached plane for the histogram editor
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class FilterSpikes(ModuleBase):
    """
    Using a rolling window along the time (/ z) dimension, identify spikes which are greatly above the median within
    that window and remove them by replacing the value with the median.

    Parameters
    ----------
    series: PYME.IO.image.ImageStack
    time_window_size: int
        Size of window to use in rolling-median and standard deviation calculations
    threshold_factor: float
        Multiplicative factor used to set the spike threshold, which is threshold * median-absolute deviation + median,
        all calculated within the window for an individual x, y, pixel.
    threshold_change: float
        Absolute increase (in data units) required between consecutive time-steps for a spike candidate to be
        considered a spike

    Returns
    -------
    output: PYME.IO.image.ImageStack
        Spike-filtered copy of the input series

    Notes
    -----
    Currently only set up for single-color data. Windows are evaluated sequentially on the already-filtered data.
    """
    input = Input('input')
    time_window_size = Int(10)
    threshold_factor = Float(5)
    threshold_change = Float(370)
    process_frames_individually = False
    output = Output('filtered')

    def _filter_spikes(self, data):
        """Return a spike-filtered copy of a 3D (x, y, t) array."""
        try:
            from scipy.stats import median_abs_deviation as mad
        except ImportError:  # scipy < 1.5
            from scipy.stats import median_absolute_deviation as mad

        window = self.time_window_size

        # Difference in floating point: unsigned camera data would otherwise wrap around on every decrease and look
        # like a huge positive jump.
        over_jump_threshold = np.zeros(data.shape, dtype=bool)
        over_jump_threshold[:, :, 1:] = np.diff(data.astype(np.float64), axis=2) > self.threshold_change

        output = np.copy(data)
        # +1 so that the final window [T - window, T) - and therefore the last frame - is also examined
        for ti in range(data.shape[2] - window + 1):
            chunk = output[:, :, ti:ti + window]
            median = np.median(chunk, axis=2)
            threshold = self.threshold_factor * mad(chunk, scale=1, axis=2) + median
            spikes = np.logical_and(chunk > threshold[:, :, None], over_jump_threshold[:, :, ti:ti + window])
            xs, ys, ts = np.nonzero(spikes)
            output[xs, ys, ts + ti] = median[xs, ys]

        return output

    def execute(self, namespace):
        series = namespace[self.input]

        # only 1 color for now. Reshape (rather than squeeze) so singleton x / y dimensions are preserved.
        data = np.asarray(series.data[:, :, :, 0]).reshape(series.data.shape[:3])
        output = self._filter_spikes(data)

        out = image.ImageStack(data=output)
        out.mdh = MetaDataHandler.NestedClassMDHandler()
        try:
            out.mdh.copyEntriesFrom(series.mdh)
        except AttributeError:
            pass
        out.mdh['Analysis.FilterSpikes.ThresholdFactor'] = self.threshold_factor
        out.mdh['Analysis.FilterSpikes.ThresholdChange'] = self.threshold_change
        out.mdh['Analysis.FilterSpikes.TimeWindowSize'] = self.time_window_size

        namespace[self.output] = out
class FiducialTrack(ModuleBase):
    """
    Extract average fiducial track from input pipeline

    Parameters
    ----------

        radiusMultiplier: this number is multiplied with error_x to obtain search radius for clustering
        timeWindow: the window along the time dimension used for clustering
        filterScale: the size of the filter kernel used to smooth the resulting average fiducial track
        filterMethod: enumrated choice of filter methods for smoothing operation (Gaussian, Median or Uniform kernel)
        clumpMinSize: passed to tfs.extractTrajectoriesClump as ``clumpMinSize`` (presumably the minimum number of
            events for a clump to count as a fiducial - confirm against trackFiducials)
        singleFiducial: if True, all data is assumed to come from a single fiducial and track fragments are not
            aligned to each other

    Notes
    -----

    Output is a new pipeline with added fiducial_x, fiducial_y columns (one ``fiducial_<dim>`` column per dimension
    returned by AverageTrack) and an ``isFiducial`` column
    """
    import PYMEcs.Analysis.trackFiducials as tfs
    inputName = Input('filtered')

    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    # list(): on Python 3 dict.keys() is a view, not a list/tuple of choices
    filterMethod = Enum(list(tfs.FILTER_FUNCS.keys()))
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        import PYMEcs.Analysis.trackFiducials as tfs

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # If all data is from a single fiducial we do not need to align; skipping the alignment avoids problems with
        # incomplete tracks giving rise to offsets between fiducial track fragments.
        align = not self.singleFiducial

        t, x, y, z, isFiducial = tfs.extractTrajectoriesClump(inp, clumpRadiusVar='error_x',
                                                              clumpRadiusMultiplier=self.radiusMultiplier,
                                                              timeWindow=self.timeWindow,
                                                              clumpMinSize=self.clumpMinSize,
                                                              align=align)
        rawtracks = (t, x, y, z)
        tracks = tfs.AverageTrack(inp, rawtracks, filter=self.filterMethod,
                                  filterScale=self.filterScale, align=align)

        # add tracks for all calculated dims to output
        for dim in tracks.keys():
            mapped.addColumn('fiducial_%s' % dim, tracks[dim])
        mapped.addColumn('isFiducial', isFiducial)

        # propogate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        return ['columns']
class RGBImageUpload(ImageUpload):
    """
    Render an image as an RGB PNG and upload it to an OMERO server, optionally attaching localization files.

    Parameters
    ----------
    input_image : str
        name of the image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Each tabular is saved as an '.hdf' file and
        attached to the image
    filePattern : str
        pattern used to name the image on the OMERO server
    omero_dataset : str
        name of the OMERO dataset to add the image to; created if it does not already exist
    omero_project : str
        name of the OMERO project to link the dataset to; created if it does not already exist
    zoom : int
        zoom factor applied when rendering the image
    scaling : str
        intensity scaling, either 'min-max' or 'percentile'
    scaling_factor : float
        only used with `percentile` scaling - the percentile to use
    colorblind_friendly : bool
        render multi-channel images with cyan, magenta, and yellow instead of RGB. True by default.

    Notes
    -----
    The OMERO server address and login must be stored in the user's PYME config directory under
    plugins/config/pyme-omero, e.g. /Users/Andrew/.PYME/plugins/config/pyme-omero. The file is yaml formatted with the
    keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        # CMY is only used for multi-channel images; the channel count is only inspected when it matters
        use_cmy = self.colorblind_friendly and image.data.shape[3] != 1
        render = image_to_cmy if use_cmy else image_to_rgb

        pixels = render(image, zoom=self.zoom, scaling=self.scaling, scaling_factor=self.scaling_factor)
        Image.fromarray(pixels, mode='RGB').save(path)
class HDBSCANClustering(ModuleBase):
    """
    Performs HDBSCAN clustering on input dictionary

    Parameters
    ----------
    columns : list of str
        Names of the columns used as clustering coordinates.
    min_clump_size : int
        Minimum cluster size, passed to HDBSCAN as ``min_cluster_size`` (also used as the minimum cluster size when
        extracting the DBSCAN clustering). Technically the only required parameter.
    search_radius : float
        If > 0, additionally extract a DBSCAN clustering with this cut distance from the HDBSCAN single-linkage
        tree. Skipped if 0.
    clump_column_name, clump_prob_column_name, clump_dbscan_column_name : str
        Names of the output columns for the HDBSCAN labels, HDBSCAN membership probabilities and DBSCAN labels.

    Notes
    -----
    See https://github.com/scikit-learn-contrib/hdbscan
    Lots of other parameters not mapped.

    Labels are shifted up by one so that, following the existing convention, a clump ID of 0 means unclumped
    (hdbscan labels noise as -1).
    """
    input_name = Input('filtered')
    columns = ListStr(['x', 'y'])
    search_radius = Float()
    min_clump_size = Int(100)
    clump_column_name = CStr('hdbscan_id')
    clump_prob_column_name = CStr('hdbscan_prob')
    clump_dbscan_column_name = CStr('dbscan_id')
    output_name = Output('hdbscan_clustered')

    def execute(self, namespace):
        inp = namespace[self.input_name]
        mapped = tabular.mappingFilter(inp)

        import hdbscan
        clusterer = hdbscan.HDBSCAN(min_cluster_size=self.min_clump_size)
        clusterer.fit(np.vstack([inp[k] for k in self.columns]).T)

        # hdbscan labels unclustered points -1 and starts clusters at 0; shift so 0 means unclumped
        mapped.addColumn(str(self.clump_column_name), clusterer.labels_ + 1)
        mapped.addColumn(str(self.clump_prob_column_name), clusterer.probabilities_)

        if self.search_radius > 0:
            # Extract dbscan clustering from hdbscan clusterer
            dbscan = clusterer.single_linkage_tree_.get_clusters(self.search_radius, self.min_clump_size)
            # same shift as above so that 0 means unclumped
            mapped.addColumn(str(self.clump_dbscan_column_name), dbscan + 1)

        # propogate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.output_name] = mapped
class DetectPSF(ModuleBase):
    """
    Detect PSF based on diff of gaussian (skimage ``feature.blob_dog``).

    Image dims in X, Y, Z, C where C are processed independently.
    Returns a list with one integer array per C; each row is (x, y, z, sigma). z is set to the middle slice
    (only ``ignore_z=True`` is implemented).
    """
    inputName = Input('input')

    min_sigma = Float(1.0)
    max_sigma = Float(3.0)
    sigma_ratio = Float(1.6)
    percent_threshold = Float(0.1)
    overlap = Float(0.5)
    exclude_border = Int(50)
    ignore_z = Bool(True)

    output_pos = Output('psf_pos')

    @staticmethod
    def _mean_projection(ims, c):
        """Mean projection along z of channel c, with pixels equal to 2**16 - 1 replaced by 200."""
        mean_project = ims.data[:, :, :, c].mean(2).squeeze()
        mean_project[mean_project == 2 ** 16 - 1] = 200
        return mean_project

    def execute(self, namespace):
        ims = namespace[self.inputName]

        pixel_size = ims.mdh['voxelsize.x']

        pos = list()
        counts = ims.data.shape[3]
        for c in np.arange(counts):
            mean_project = self._mean_projection(ims, c)
            mean_project -= mean_project.min()
            peak = mean_project.max()
            if peak > 0:
                # avoid 0/0 -> NaN for a flat projection
                mean_project /= peak

            # blob_dog's own exclude_border is not used so that older skimage versions still work; border blobs are
            # removed manually below
            blobs = feature.blob_dog(mean_project, self.min_sigma / pixel_size, self.max_sigma / pixel_size,
                                     overlap=self.overlap, threshold=self.percent_threshold * mean_project.max())

            edge_mask = (blobs[:, 0] > self.exclude_border) & (blobs[:, 0] < mean_project.shape[0] - self.exclude_border)
            edge_mask &= (blobs[:, 1] > self.exclude_border) & (blobs[:, 1] < mean_project.shape[1] - self.exclude_border)
            blobs = blobs[edge_mask]

            # is list of x, y, sig
            if self.ignore_z:
                blobs = np.insert(blobs, 2, ims.data.shape[2] // 2, axis=1)
            else:
                raise Exception("z centering not yet implemented")

            pos.append(blobs.astype(int))

        namespace[self.output_pos] = pos

        # best-effort diagnostic plot of the detections
        try:
            fig, axes = pyplot.subplots(1, counts, figsize=(4 * counts, 3), squeeze=False)
            for c in np.arange(counts):
                axes[0, c].imshow(self._mean_projection(ims, c))
                axes[0, c].set_axis_off()
                for x, y, z, sig in pos[c]:
                    cir = pyplot.Circle((y, x), sig, color='red', linewidth=2, fill=False)
                    axes[0, c].add_patch(cir)
        except Exception as e:
            print(e)
class CropPSF(ModuleBase):
    """
    Crops out PSF based on positions given.

    Built-in filter by flattened index (``ignore_pos``).
    Built-in filter for removing multiple peaked or off-centre data (applied to the z-averaged X, Y crop).
    Crops are stacked along the 4th dimension of the output.
    """
    inputName = Input('input')
    input_pos = Input('psf_pos')

    ignore_pos = List(Int, [])
    threshold_reject = Float(0.5)
    com_reject = Float(2.0)

    half_roi_x = Int(20)
    half_roi_y = Int(20)
    half_roi_z = Int(60)

    output_images = Output('psf_cropped')

    def _is_rejected(self, crop_flatten):
        """Return True unless the z-averaged crop shows exactly one above-threshold region near its centre."""
        labeled_image, labeled_counts = ndimage.label(crop_flatten > crop_flatten.max() * self.threshold_reject)
        if labeled_counts != 1:
            # several peaks, or nothing above threshold at all (e.g. an empty crop)
            return True

        com = np.asarray(ndimage.center_of_mass(crop_flatten, labeled_image, 1))
        img_center = np.asarray([(s - 1) * 0.5 for s in labeled_image.shape])
        return np.linalg.norm(com - img_center) > self.com_reject

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_pos = namespace[self.input_pos]

        roi_shape = (self.half_roi_x * 2 + 1, self.half_roi_y * 2 + 1, self.half_roi_z * 2 + 1)
        n_psfs = sum([ar.shape[0] for ar in psf_pos])
        res = np.zeros(roi_shape + (n_psfs,))

        # Work on a copy so automatically rejected crops do not accumulate in the user-facing parameter and leak
        # into later runs on different data.
        ignore_pos = list(self.ignore_pos)

        counter = 0
        for c in np.arange(ims.data.shape[3]):
            for i in np.arange(len(psf_pos[c])):
                x, y, z = psf_pos[c][i][:3]
                # NOTE(review): positions closer than half_roi to an edge give negative slice starts and a
                # shape mismatch below
                x_slice = slice(x - self.half_roi_x, x + self.half_roi_x + 1)
                y_slice = slice(y - self.half_roi_y, y + self.half_roi_y + 1)
                z_slice = slice(z - self.half_roi_z, z + self.half_roi_z + 1)
                res[:, :, :, counter] = ims.data[x_slice, y_slice, z_slice, c].squeeze()

                crop_flatten = res[:, :, :, counter].mean(2)
                if self._is_rejected(crop_flatten) and counter not in ignore_pos:
                    ignore_pos.append(counter)

                counter += 1

        print("images ignore: {}".format(ignore_pos))
        mask = np.ones(n_psfs, dtype=bool)
        mask[ignore_pos] = False

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["ImageType"] = 'PSF'
            if "PSFExtraction.SourceFilenames" not in new_mdh.keys():
                new_mdh["PSFExtraction.SourceFilenames"] = ims.filename
        except Exception as e:
            print(e)

        namespace[self.output_images] = ImageStack(data=res[:, :, :, mask], mdh=new_mdh)
class ShiftImage(CacheCleanupModule):
    """
    Performs FT based image shift. Only shift in 2D.

    Inputs
    ------
    input_image : ImageStack
        Images with drift.
    input_drift_interpolator :
        Pair of callables returning the x and y drift when called with frame number / time.

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    padding_multipler : Int
        Padding (as multiple of image size) added to the image before shifting to avoid artifacts.
    cache_image : File
        Use file as disk cache if provided; if empty, the result is held in memory.
    """
    input_image = Input('input')
    input_drift_interpolator = Input('drift_interpolator')

    padding_multipler = Int(1)

    cache_image = File("shifted_image.bin")

    outputName = Output('drift_corrected_image')

    def _execute(self, namespace):
        self._start_time = time.time()

        ims = namespace[self.input_image]

        # frame times; if the image was binned, evaluate the drift at the centre of each bin
        t_out = np.arange(ims.data.shape[2], dtype=float)
        if 'recipe.binning' in ims.mdh.keys():
            t_out *= ims.mdh['recipe.binning'][2]
            t_out += 0.5 * ims.mdh['recipe.binning'][2]

        drift_interpolator = namespace[self.input_drift_interpolator]
        dx = drift_interpolator[0](t_out)
        dy = drift_interpolator[1](t_out)

        shifted_images = self.shift_images(ims, np.stack([dx, dy], 1), ims.mdh)

        namespace[self.outputName] = ImageStack(shifted_images, titleStub=self.outputName, mdh=ims.mdh)

    def shift_images(self, ims, shifts, mdh):
        """
        Shift every frame of ``ims`` by the matching row of ``shifts`` using a Fourier-domain shift.

        ``shifts`` is (n_frames, 2) in the units of the drift interpolator; it is converted to pixels using
        ``mdh.voxelsize`` (divided by 1E3 if the voxel size is given in 'um').
        """
        import warnings

        padding = np.stack((ims.data.shape[:2], ) * 2, -1)  # 2d only
        padding *= self.padding_multipler

        padded_image_shape = np.asarray(ims.data.shape[:2], dtype=np.int64) + padding.sum((1))
        dtype = ims.data[:, :, 0].dtype
        padded_image = np.zeros(padded_image_shape, dtype=dtype)

        kx = (np.fft.fftfreq(padded_image_shape[0]))
        ky = (np.fft.fftfreq(padded_image_shape[1]))
        kx, ky = np.meshgrid(kx, ky, indexing='ij')

        images_shape = tuple(np.asarray(ims.data.shape[:3], dtype=np.int64))

        if self.cache_image == "":
            shifted_images = np.empty(images_shape)
        else:
            shifted_images = np.memmap(self.cache_image, dtype=float, mode='w+', shape=images_shape)

        shifts_in_pixels = np.copy(shifts)
        try:
            shifts_in_pixels[:, 0] = shifts[:, 0] / mdh.voxelsize.x
            shifts_in_pixels[:, 1] = shifts[:, 1] / mdh.voxelsize.y
            if mdh.voxelsize.units == 'um':
                shifts_in_pixels /= 1E3
        except Exception as e:
            # actually emit the warning - the shifts below may not be in pixel units
            warnings.warn("Failed at converting drift in pixels to real distances: %r" % (e,))

        x_pad, y_pad = padding[0, 0], padding[1, 0]
        nx, ny = ims.data.shape[0], ims.data.shape[1]
        for i in np.arange(ims.data.shape[2]):
            padded_image[x_pad:x_pad + nx, y_pad:y_pad + ny] = ims.data[:, :, i].squeeze()
            ft_image = np.fft.fftn(padded_image)

            data_shifted = shift_image_direct_rough(ft_image, shifts_in_pixels[i], kxy=(kx, ky))
            shifted_images[:, :, i] = data_shifted[x_pad:x_pad + nx, y_pad:y_pad + ny]

            if ((i + 1) % max(shifted_images.shape[-1] // 5, 1) == 0):
                if isinstance(shifted_images, np.memmap):
                    shifted_images.flush()
                print("{:.2f} s. Completed shifting {} of {} total images.".format(
                    time.time() - self._start_time, i + 1, shifted_images.shape[-1]))

        return shifted_images
class AlignPSF(ModuleBase):
    """
    Align PSF stacks by redundant cross correlation.

    Every pair of PSF stacks (input dims X, Y, Z, N) is cross-correlated. The fitted peak positions are combined in
    a least-squares solve for per-stack drifts, discarding pairs whose residual exceeds ``rcc_tolerance``. The
    drifts are then applied with a Fourier shift.
    """
    inputName = Input('psf_cropped')
    normalize_z = Bool(True)
    tukey = Float(0.50)
    rcc_tolerance = Float(5.0)
    z_crop_half_roi = Int(15)
    peak_detect = Enum(['Gaussian', 'RBF'])
    debug = Bool(False)
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C'
        psf_stack = ims.data[:, :, :, :]

        z_slice = slice(psf_stack.shape[2] // 2 - self.z_crop_half_roi,
                        psf_stack.shape[2] // 2 + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:, :, z_slice, :])

        if self.tukey > 0:
            try:
                tukey_window = signal.windows.tukey
            except AttributeError:  # very old scipy only exposes signal.tukey
                tukey_window = signal.tukey
            masks = [tukey_window(dim_len, self.tukey) for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:, :, :, None]

        # NOTE(review): tolerance converted to pixels assuming rcc_tolerance is in um and voxelsize.x in nm - confirm
        drifts = self.calculate_shifts(cleaned_psf_stack, self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        namespace[self.output_images] = ImageStack(
            self.shift_images(cleaned_psf_stack if self.debug else psf_stack, drifts), mdh=ims.mdh)

    def normalize_images(self, psf_stack):
        """Clip negatives, subtract per-stack background and normalise intensity per plane or per stack."""
        # in case it is already bg subtracted
        cleaned_psf_stack = np.clip(psf_stack, 0, None)

        # substact bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0, 1, 2), keepdims=True)

        if self.normalize_z:
            # normalize intensity per plane
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1), keepdims=True) / 1.05
        else:
            # normalize intensity per psf stack
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1, 2), keepdims=True) / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """
        Estimate the drift of every stack in ``psf_stack`` (X, Y, Z, N) by redundant cross correlation.

        Returns an (N, 3) array of drifts, relative to the first stack and then offset by the centre of mass of
        the mean PSF.
        """
        n_steps = psf_stack.shape[3]
        # one row per (i, j) pair; integer division keeps this an int under Python 3
        coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        output_cross_corr_images = np.zeros(psf_stack.shape[:3] + (coefs_size,), dtype=float)
        output_cross_corr_images_fitted = np.zeros(psf_stack.shape[:3] + (coefs_size,), dtype=float)

        counter = 0
        for i in np.arange(0, n_steps - 1):
            for j in np.arange(i + 1, n_steps):
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))
                correlate_result = signal.correlate(psf_stack[:, :, :, i], psf_stack[:, :, :, j], mode="same")
                correlate_result -= correlate_result.min()
                correlate_result /= correlate_result.max()

                threshold = 0.50
                correlate_result[correlate_result < threshold] = np.nan

                labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
                # protects against > 1 peak in the cross correlation results
                # shouldn't happen anyway, but at least avoid fitting a single peak to multi-modal data:
                # keep only the region containing the strongest peak
                if labeled_counts > 1:
                    peak_maxima = ndimage.maximum(correlate_result, labeled_image, np.arange(labeled_counts) + 1)
                    strongest_label = np.argmax(peak_maxima) + 1
                    correlate_result[labeled_image != strongest_label] = np.nan

                output_cross_corr_images[:, :, :, counter] = np.nan_to_num(correlate_result)

                # coordinates along each axis, centred on the middle of the array
                dims = [np.arange(dim) - np.arange(dim).mean() for dim in correlate_result.shape]

                if self.peak_detect == "Gaussian":
                    res = optimize.least_squares(guassian_nd_error,
                                                 [1, 0, 0, 5., 0, 5., 0, 30.],
                                                 args=(dims, correlate_result))
                    output_cross_corr_images_fitted[:, :, :, counter] = gaussian_nd(res.x, dims)

                    shifts[counter, 0] = res.x[2]
                    shifts[counter, 1] = res.x[4]
                    shifts[counter, 2] = res.x[6]

                elif self.peak_detect == "RBF":
                    rbf_interpolator = build_rbf(dims, correlate_result)
                    res = optimize.minimize(rbf_nd_error,
                                            [correlate_result.shape[0] * 0.5, correlate_result.shape[1] * 0.5,
                                             correlate_result.shape[2] * 0.5],
                                            args=rbf_interpolator)
                    output_cross_corr_images_fitted[:, :, :, counter] = rbf_nd(rbf_interpolator, dims)
                    shifts[counter, :] = res.x
                else:
                    raise Exception("peak founding method not recognised")

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals = np.matmul(coefs, drifts) - shifts
        residuals_dist = np.linalg.norm(residuals, axis=1)

        shift_max = drift_tolerance
        # Sort and mask residual errors
        residuals_arg = np.argsort(-residuals_dist)
        residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

        # Remove coefs rows, descending from largest residuals to smallest, only while the matrix remains full rank
        coefs_temp = np.empty_like(coefs)
        n_removed = 0
        for index in residuals_arg:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                n_removed += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(n_removed))

        # pairwise step drifts -> cumulative drift of each stack relative to the first
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)
        np.cumsum(drifts, axis=0, out=drifts)

        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0, 1, 2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        center_offset = ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape) * 0.5

        drifts = drifts - center_offset

        self._plot_drifts(drifts)

        return drifts

    def _plot_drifts(self, drifts):
        """Best-effort scatter plots of the drifts (x vs y and x vs z)."""
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6, 3))

            axes[0].scatter(drifts[:, 0], drifts[:, 1], s=50)
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            axes[1].scatter(drifts[:, 0], drifts[:, 2], s=50)
            axes[1].set_xlabel('x')
            axes[1].set_ylabel('z')

            for ax in axes:
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')

            fig.tight_layout()
        except Exception as e:
            print(e)

    def shift_images(self, psf_stack, shifts):
        """Fourier-shift each stack psf_stack[..., i] by shifts[i] (in pixels) and return the magnitudes."""
        kx = (np.fft.fftfreq(psf_stack.shape[0]))
        ky = (np.fft.fftfreq(psf_stack.shape[1]))
        kz = (np.fft.fftfreq(psf_stack.shape[2]))
        kx, ky, kz = np.meshgrid(kx, ky, kz, indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in np.arange(psf_stack.shape[3]):
            ft_image = np.fft.fftn(psf_stack[:, :, :, i])
            shift = shifts[i]
            shifted_images[:, :, :, i] = np.abs(
                np.fft.ifftn(ft_image * np.exp(-2j * np.pi * (kx * shift[0] + ky * shift[1] + kz * shift[2]))))

        return shifted_images
class Binning(CacheCleanupModule):
    """ Downsample 3D data (mean).

    X, Y pixels that don't fill a full bin are dropped.
    Pixels in the 3rd dimension can have a partially filled last bin, which is
    averaged over the frames it does contain.

    Inputs
    ------
    inputName : ImageStack

    Outputs
    -------
    outputName : ImageStack
        Binned image. ``voxelsize.x`` / ``voxelsize.y`` are scaled by the x / y
        bin sizes and ``recipe.binning`` records the cumulative bin size.

    Parameters
    ----------
    x_start : int
        Starting index in x.
    x_end : int
        Stopping index in x. Negative values count from the end and are
        inclusive (the default ``-1`` keeps every pixel up to the edge);
        non-negative values are exclusive, as in a normal slice.
    y_start : int
        Starting index in y.
    y_end : int
        Stopping index in y, same convention as ``x_end``.
    binsize : list of 3 int
        Bin size in x, y and the 3rd dimension.
    cache_bin : File
        Use file as disk cache if provided; leave blank to bin in memory.

    Notes
    -----
    Multi-colour images are not yet handled separately.
    """
    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        import logging

        self._start_time = time.time()
        ims = namespace[self.inputName]
        n_frames = ims.data.shape[2]

        # builtin ``int`` instead of the ``np.int`` alias (removed in NumPy 1.24)
        binsize = np.asarray(self.binsize, dtype=int)

        # Unconventional: the arange is one longer than the axis so that a
        # negative end stop (e.g. the default -1) is inclusive of the last pixel.
        x_slice = np.arange(ims.data.shape[0] + 1)[slice(self.x_start, self.x_end, 1)]
        y_slice = np.arange(ims.data.shape[1] + 1)[slice(self.y_start, self.y_end, 1)]
        # Drop trailing x / y pixels that do not fill a whole bin.
        x_slice = x_slice[:x_slice.shape[0] // binsize[0] * binsize[0]]
        y_slice = y_slice[:y_slice.shape[0] // binsize[1] * binsize[1]]
        if x_slice.size == 0 or y_slice.size == 0:
            raise ValueError(
                "Binning region ({} x {} pixels after trimming) is smaller than "
                "one bin (binsize={})".format(len(x_slice), len(y_slice), list(binsize)))

        bincounts = np.asarray([
            len(x_slice) // binsize[0],
            len(y_slice) // binsize[1],
            -(-n_frames // binsize[2]),  # ceil division: keep a partial last bin
        ], dtype=np.int64)

        x_slice_ind = slice(x_slice[0], x_slice[-1] + 1)
        y_slice_ind = slice(y_slice[0], y_slice[-1] + 1)

        # Each chunk of frames is reshaped to (nx, bx, ny, by, 1, -1) and averaged
        # over axes 1, 3 and 5; the -1 absorbs a partially filled last z bin.
        # NOTE: needs wrapping to work for multi-colour images.
        chunk_shape = np.stack([bincounts, binsize], -1).flatten()
        chunk_shape[4] = 1
        chunk_shape[5] = -1

        dtype = ims.data[:, :, 0].dtype
        out_shape = tuple(int(n) for n in bincounts)
        if not self.cache_bin:
            binned_image = np.zeros(out_shape, dtype=dtype)
        else:
            binned_image = np.memmap(self.cache_bin, dtype=dtype, mode='w+', shape=out_shape)

        progress = 0.2 * n_frames
        for i, f in enumerate(np.arange(0, n_frames, binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind, f:f + binsize[2]].squeeze()
            binned_image[:, :, i] = raw_data_chunk.reshape(chunk_shape).mean((1, 3, 5)).squeeze()

            # report (and flush the disk cache) roughly every 20 %
            if f + binsize[2] >= progress:
                if isinstance(binned_image, np.memmap):
                    binned_image.flush()
                progress += 0.2 * n_frames
                print("{:.2f} s. Completed binning {} of {} total images.".format(
                    time.time() - self._start_time, min(f + binsize[2], n_frames), n_frames))

        im = ImageStack(binned_image, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename
        try:
            # Metadata must be logged correctly for the measured drift to be
            # applicable to the source image.
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            # NOTE(review): voxelsize.z is deliberately not scaled by binsize[2]
            # (it was commented out originally) - confirm this is intended.
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except Exception:
            # Best effort (e.g. source lacks voxel size metadata), but don't hide it.
            logging.getLogger(__name__).exception(
                "Binning: could not update voxelsize / recipe.binning metadata")

        namespace[self.outputName] = im
class LabelRange(Filter):
    """Assign a unique integer label to each contiguous region in the input mask.

    Regions whose pixel area is outside ``[minRegionPixels, maxRegionPixels]``
    are thrown away. The number of sites inside each remaining region is then
    counted from a second input, and only regions with a site count in
    ``[minSites, maxSites]`` are retained. Retained regions are relabelled.

    Inputs
    ------
    inputName : ImageStack
        Mask image; pixels > 0.5 are treated as foreground.
    inputSitesLabeled : ImageStack
        Sites image with the same x, y, z extent as ``inputName``. By default
        sites are counted as unique non-zero labels; if ``sitesAsMaxima`` is
        set, the image values inside a region are summed instead.

    Outputs
    -------
    outputName : ImageStack
        Label image of the retained regions.
    """
    inputSitesLabeled = Input("sites")  # sites and the main input must have the same shape!
    minRegionPixels = Int(10)
    maxRegionPixels = Int(100)
    minSites = Int(4)
    maxSites = Int(6)
    sitesAsMaxima = Bool(False)

    def filter(self, image, imagesites):
        """Run :meth:`applyFilter` on every channel (and, when
        ``processFramesIndividually`` is set, every frame) and wrap the result.

        Raises
        ------
        ValueError
            If the x, y, z extents of ``image`` and ``imagesites`` differ.
        """
        if tuple(image.data.shape[:3]) != tuple(imagesites.data.shape[:3]):
            raise ValueError(
                "LabelRange: sites image shape {} does not match input shape {}".format(
                    tuple(imagesites.data.shape), tuple(image.data.shape)))

        n_frames = image.data.shape[2]
        n_chans = image.data.shape[3]

        if self.processFramesIndividually:
            filt_ims = []
            for chanNum in range(n_chans):
                frames = [
                    np.atleast_3d(self.applyFilter(
                        image.data[:, :, i, chanNum].squeeze().astype('f'),
                        imagesites.data[:, :, i, chanNum].squeeze().astype('f'),
                        chanNum, i, image))
                    for i in range(n_frames)
                ]
                filt_ims.append(np.concatenate(frames, 2))
        else:
            filt_ims = [
                np.atleast_3d(self.applyFilter(
                    image.data[:, :, :, chanNum].squeeze().astype('f'),
                    imagesites.data[:, :, :, chanNum].squeeze().astype('f'),
                    chanNum, 0, image))
                for chanNum in range(n_chans)
            ]

        im = ImageStack(filt_ims, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(image.mdh)
        im.mdh['Parent'] = image.filename

        self.completeMetadata(im)

        return im

    def execute(self, namespace):
        namespace[self.outputName] = self.filter(namespace[self.inputName],
                                                 namespace[self.inputSitesLabeled])

    def applyFilter(self, data, sites, chanNum, frNum, im):
        """Label ``data > 0.5`` and keep regions passing the size and site-count limits.

        ``chanNum``, ``frNum`` and ``im`` are unused here but kept for the
        common ``applyFilter`` signature.
        """
        mask = data > 0.5
        labs, _ = ndimage.label(mask)

        keep = np.zeros(mask.shape, dtype=bool)
        for label, bbox in enumerate(ndimage.find_objects(labs), start=1):
            if bbox is None:
                continue
            region = labs[bbox] == label
            area = region.sum()
            if not (self.minRegionPixels <= area <= self.maxRegionPixels):
                continue

            if self.sitesAsMaxima:
                nsites = sites[bbox][region].sum()
            else:
                # count the unique labels (excluding label 0 which is background)
                nsites = (np.unique(sites[bbox][region]) > 0).sum()

            if self.minSites <= nsites <= self.maxSites:
                keep[bbox] |= region

        labs, _ = ndimage.label(keep)
        return labs

    def completeMetadata(self, im):
        im.mdh['Labelling.MinSize'] = self.minRegionPixels
        im.mdh['Labelling.MaxSize'] = self.maxRegionPixels
        im.mdh['Labelling.MinSites'] = self.minSites
        im.mdh['Labelling.MaxSites'] = self.maxSites
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.

    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Runtime will vary hugely depending on size of dataset and settings.

    ``cache_fft`` is necessary for large datasets.

    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.*   Dataset to correct.

    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results.
    output_drift_plot : Plot
        *Deprecated.*   Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.*   Drift-corrected dataset.

    Parameters
    ----------
    step : Int
        Setting for image construction. Step size (in frames) between images.
    window : Int
        Setting for image construction. Number of frames used per image. Should be equal or larger than step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Float
        Size of correlation window. Frames are only compared if within this frame range. N/A for DCC.
    multiprocessing : Bool
        Enables multiprocessing.
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """
    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')

    def calc_corr_drift_from_locs(self, x, y, z, t):
        """Histogram localisations in sliding time windows and cross-correlate them.

        Parameters
        ----------
        x, y, z, t : ndarray
            Localisation coordinates and frame numbers (equal-length 1D arrays).

        Returns
        -------
        time_values_mid : ndarray
            Centre time of each window (the last window may be shorter).
        shifts : ndarray
            Shifts from ``calc_corr_drift_from_ft_images`` scaled by ``binsize``,
            with columns in x, y, z order.
        coefs : ndarray
            Coefficients from ``calc_corr_drift_from_ft_images``.
        """
        # bin edges for histogram
        bx = np.arange(x.min(), x.max() + self.binsize + 1, self.binsize)
        by = np.arange(y.min(), y.max() + self.binsize + 1, self.binsize)
        bz = np.arange(z.min(), z.max() + self.binsize + 1, self.binsize)

        # Pad to an odd number of edges so the image (edges - 1 pixels) has an
        # even size. z is left alone when it is a single slice (2 edges).
        if bx.shape[0] % 2 == 0:
            bx = np.concatenate([bx, [bx[-1] + bx[1] - bx[0]]])
        if by.shape[0] % 2 == 0:
            by = np.concatenate([by, [by[-1] + by[1] - by[0]]])
        if bz.shape[0] > 2 and bz.shape[0] % 2 == 0:
            bz = np.concatenate([bz, [bz[-1] + bz[1] - bz[0]]])
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Ops. Image not correctly padded to even size."

        # Start frame of all windows; allow a partial window near the end.
        window_starts = np.arange(t.min(), t.max() + 1, self.step)
        # Half-open [start, end) frame range of each window. Clip to t.max() + 1
        # (not t.max()) so the last frame is included and the final window can
        # never be empty.
        time_values = np.stack([window_starts,
                                np.clip(window_starts + self.window, None, t.max() + 1)], axis=1)
        n_steps = time_values.shape[0]
        # Centre time of each window, for returning. Last window may have different spacing.
        time_values_mid = time_values.mean(axis=1)

        if np.any(np.diff(t) < 0):  # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]

        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1] - 1, side='right')

        # Crude way of swapping the longest axis to the last for rfft performance.
        # The permutation is a single swap, so it is its own inverse (used on return).
        xyz = np.asarray([x, y, z])
        # Edge arrays generally differ in length; build an explicit object array
        # (implicit ragged-array creation raises in NumPy >= 1.24).
        bxyz = np.empty(3, dtype=object)
        for k, edges in enumerate((bx, by, bz)):
            bxyz[k] = edges
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = bxyz[dims_order]
        dims_length = dims_length[dims_order]

        # Fourier transformed (and binned) set of images to correlate against one
        # another; use memmap for caching if cache_fft is defined.
        ft_shape = (n_steps, dims_length[0] - 1, dims_length[1] - 1, (dims_length[2] - 1) // 2 + 1)
        if self.cache_fft == "":
            ft_images = np.zeros(ft_shape, dtype=complex)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=complex, mode='w+', shape=ft_shape)

        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))

        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))

        # Progress is reported every ~20 %; guard against n_steps < 5 (modulo by zero).
        report_every = max(n_steps // 5, 1)

        # fill ft_images
        # if multiprocessing, can either use or not caching
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:, slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size)
                    for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                if self.cache_fft == "":
                    ft_images[j] = res

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(
                        time.time() - self._start_time, i + 1, n_steps))
        else:
            # For each window, generate an image and store its ft.
            for i, ti in enumerate(time_indexes):
                ft_images[i] = calc_fft_from_locs(xyz[:, slice(*ti)].T, bxyz, filter_size=self.tukey_size)

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(
                        time.time() - self._start_time, i + 1, n_steps))

        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))

        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)

        # clean up of ft_images, potentially really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images

        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=process_count)})

        locs = namespace[self.input_for_correction]

        try:
            drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'],
                                                       locs['z'] * (0 if self.flatten_z else 1),
                                                       locs['t'])
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # Always release the worker processes, even if correlation fails.
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        # Must be checked before the columns are (re)added below.
        had_drift_columns = 'dx' in out.keys()
        out.addColumn('dx', dx)
        out.addColumn('dy', dy)
        out.addColumn('dz', dz)
        if had_drift_columns:
            # Getting around an oddity with mappingFilter: addColumn adds a new
            # column but also keeps the old one; __getitem__ returns the new
            # column but mappings use the old one. Wrap with another level of
            # mappingFilter so the new column becomes the 'old column'.
            out = tabular.mappingFilter(out)
        out.setMapping('x', 'x + dx')
        out.setMapping('y', 'y + dy')
        out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))

        namespace[self.output_cross_cor] = self._cc_image
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Take a pair of images and calculate the Fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    input_image_a : ImageStack
        First of two images.

    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images. If blank, the first image is used.
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or perform a maximum projection (<0) for the first image.
    image_b_z : int
        Ignored unless flatten_z is True. In which case either select the z plane to use (>=0) or perform a maximum projection (<0) for the second image.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing  : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    """
    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)

    @staticmethod
    def _flatten_to_2d(data, z_index):
        """Reduce a (squeezed) image to 2D: pick plane ``z_index`` if >= 0,
        otherwise take the maximum projection along z.

        A single-plane image has already been squeezed to 2D; it is returned
        as-is (previously this raised an AxisError for the default z index).
        """
        if data.ndim < 3:
            if z_index > 0:
                raise IndexError(
                    "z index {} requested but image has a single z plane".format(z_index))
            return data
        if z_index >= 0:
            return data[:, :, z_index]
        return data.max(2)

    def execute(self, namespace):
        import multiprocessing

        self._namespace = namespace

        if self.multiprocessing:
            # Two workers, but never fewer than one: cpu_count() - 1 is 0 on a
            # single-core machine, which Pool() rejects.
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            image_a = namespace[self.input_image_a]

            if len(self.image_b_path.strip()) == 0:
                image_b = image_a
            else:
                image_b = ImageStack(filename=self.image_b_path)

            self._pixel_size_in_nm = np.zeros(3, dtype=float)
            self._pixel_size_in_nm[0] = image_a.mdh.voxelsize.x
            self._pixel_size_in_nm[1] = image_a.mdh.voxelsize.y
            try:
                self._pixel_size_in_nm[2] = image_a.mdh.voxelsize.z
            except Exception:
                # z voxel size is optional (presumably absent for 2D data); leave it at 0.
                pass
            if image_a.mdh.voxelsize.units == 'um':
                self._pixel_size_in_nm *= 1.E3

            image_a_data = image_a.data[:, :, :, self.c_channel].squeeze()
            image_b_data = image_b.data[:, :, :, self.c_channel].squeeze()

            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
                image_a_data = self._flatten_to_2d(image_a_data, self.image_a_z)
                image_b_data = self._flatten_to_2d(image_b_data, self.image_b_z)

            image_pair = self.preprocess_images([image_a_data, image_b_data])

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # Always release the worker processes, even if the FRC fails.
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)