class DriftOutput(ModuleBase):
    """
    Save drift data to a file.

    Inputs
    ------
    input_name :
        Drift measured from localization or image dataset, as a
        ``(tIndex, drift)`` pair.

    Outputs
    -------
    output_dummy : None
        Blank output. Required to run correctly.

    Parameters
    ----------
    save_path : File
        Filepath to save drift data. Written with ``numpy.savez_compressed``
        (which appends ``.npz`` if missing) under the keys ``tIndex`` and
        ``drift``.
    """
    input_name = Input('drift')
    save_path = File('drift')
    output_dummy = Output('dummy')  # will not run execute without this

    def execute(self, namespace, context=None):
        """
        Write the ``(tIndex, drift)`` pair at ``namespace[self.input_name]``
        to ``self.save_path``.

        Parameters
        ----------
        namespace : dict
            The recipe namespace.
        context : dict, optional
            Accepted for interface compatibility; not used. (Previously a
            shared mutable ``{}`` default.)
        """
        tIndex, drift = namespace[self.input_name]
        np.savez_compressed(self.save_path, tIndex=tIndex, drift=drift)
        print('saved')
class FileSource(PointSource):
    """
    Point source whose points are read from a file with ``numpy.load``.

    NOTE(review): this file defines a second ``FileSource`` class (with
    ``genMetaData``); whichever definition runs later replaces the other.
    """
    file = File()

    def getPoints(self):
        """Return the array stored in ``self.file``."""
        from numpy import load as _np_load
        points = _np_load(self.file)
        return points
class FileSource(PointSource):
    """
    Point source whose points are read from a file with ``numpy.load``. It
    also records the file it came from in the generated-points metadata.
    """
    file = File()

    def getPoints(self):
        """Load and return the points stored in ``self.file``."""
        import numpy
        return numpy.load(self.file)

    def genMetaData(self, mdh):
        """Record the source type ('File') and the source file name in ``mdh``."""
        entries = (
            ('GeneratedPoints.Source.Type', 'File'),
            ('GeneratedPoints.Source.FileName', self.file),
        )
        for key, value in entries:
            mdh[key] = value
class CombineBeadStacks(ModuleBase):
    """
    Combine multiple bead stacks in the 4th dimension.

    X, Y, Z must be identical.

    Parameters
    ----------
    files : list of File
        Bead stacks to concatenate, in order, along the 4th axis.
    cache : File
        If not blank, the combined data is written to a ``numpy.memmap`` at
        this path instead of being held in memory.
    """
    inputName = Input('dummy')
    files = List(File, ['', ''], 2)
    cache = File()
    outputName = Output('bead_images')

    def execute(self, namespace):
        # First pass: take dtype and metadata from the first stack, and sum
        # the 4th-axis lengths of all stacks to size the output.
        ims = ImageStack(filename=self.files[0])
        # np.long was removed in NumPy 1.24; use an explicit 64-bit integer
        dims = np.asarray(ims.data.shape, dtype=np.int64)
        dims[3] = 0
        dtype_ = ims.data[:, 0, 0, 0].dtype
        mdh = ims.mdh
        del ims

        for fil in self.files:
            ims = ImageStack(filename=fil)
            dims[3] += ims.data.shape[3]
            del ims

        shape = tuple(int(d) for d in dims)
        if self.cache != '':
            raw_data = np.memmap(self.cache, dtype=dtype_, mode='w+', shape=shape)
        else:
            raw_data = np.zeros(shape=shape, dtype=dtype_)

        # Second pass: copy each stack into its slot along the 4th axis
        counter = 0
        for fil in self.files:
            ims = ImageStack(filename=fil)
            c_len = ims.data.shape[3]
            data = ims.data[:, :, :, :]
            # pad trailing singleton axes so the data is always 4D
            data.shape += (1,) * (4 - data.ndim)
            raw_data[:, :, :, counter:counter + c_len] = data
            counter += c_len
            del ims

        if isinstance(raw_data, np.memmap):
            raw_data.flush()

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(mdh)
            new_mdh["PSFExtraction.SourceFilenames"] = self.files
        except Exception as e:
            print(e)

        namespace[self.outputName] = ImageStack(data=raw_data, mdh=new_mdh)
class svmSegment(Filter):
    """
    Segment an image with an SVM classifier
    (``PYME.Analysis.svmSegment.svmClassifier``) loaded from ``classifier``.
    """
    classifier = File('')

    def _loadClassifier(self):
        """
        Load the classifier from ``self.classifier``.

        The loaded classifier is cached. It is reloaded when the classifier
        path has changed since the last load, so editing ``classifier``
        never leaves a stale model in use.
        """
        from PYME.Analysis import svmSegment
        if (not hasattr(self, '_cf')) or getattr(self, '_cf_filename', None) != self.classifier:
            self._cf = svmSegment.svmClassifier(filename=self.classifier)
            self._cf_filename = self.classifier

    def applyFilter(self, data, chanNum, frNum, im):
        """Classify ``data`` (converted to float32) and return the result."""
        self._loadClassifier()
        return self._cf.classify(data.astype('f'))

    def completeMetadata(self, im):
        """Record which classifier file was used."""
        im.mdh['SVMSegment.classifier'] = self.classifier
class CNNFilter(Filter):
    """
    Use a previously trained Keras neural network to filter the data. Used for
    learnt de-noising and/or deconvolution. Runs prediction piecewise over the
    image with over-lapping ROIs and averages the prediction results.

    Notes
    -----
    Keras and either Tensorflow or Theano must be installed and set up for
    this module to work. This is not a default dependency of
    python-microscopy as the conda-installable versions don't have GPU
    support.
    """
    model = File('')
    step_size = Int(14)

    def _load_model(self):
        """Load the Keras model from ``self.model``; reload only if the path changed."""
        from keras.models import load_model
        if not getattr(self, '_model_name', None) == self.model:
            self._model_name = self.model
            self._model = load_model(self._model_name)  # TODO - make cluster-aware

    @staticmethod
    def _window_starts(length, kernel_size, step):
        """
        Return the start offsets of windows of ``kernel_size`` spaced ``step``
        apart along an axis of ``length`` pixels.

        A final window flush with the far edge is appended so the whole axis
        is covered. Returns an empty list if the axis is shorter than the
        kernel.
        """
        last = length - kernel_size
        if last < 0:
            return []
        starts = list(range(0, last + 1, step))
        if starts[-1] != last:
            starts.append(last)
        return starts

    def applyFilter(self, data, chanNum, frNum, im):
        """
        Predict over overlapping windows and return the per-pixel mean of all
        predictions covering each pixel (float32, same shape as ``data``).
        """
        self._load_model()

        out = np.zeros(data.shape, 'f')
        # number of windows that contributed to each pixel
        weight = np.zeros(data.shape, 'f')

        _, kernel_size_x, kernel_size_y, _ = self._model.input_shape

        for i_x in self._window_starts(out.shape[0], kernel_size_x, self.step_size):
            for i_y in self._window_starts(out.shape[1], kernel_size_y, self.step_size):
                roi = (slice(i_x, i_x + kernel_size_x), slice(i_y, i_y + kernel_size_y))
                d_i = data[roi].reshape(1, kernel_size_x, kernel_size_y, 1)
                p = self._model.predict(d_i).reshape(kernel_size_x, kernel_size_y)
                out[roi] += p
                weight[roi] += 1

        # Average the overlapping predictions. Pixels no window reached (only
        # possible when the image is smaller than the kernel) stay zero.
        np.divide(out, weight, out=out, where=weight > 0)
        return out

    def completeMetadata(self, im):
        """Record which model file was used."""
        im.mdh['CNNFilter.model'] = self.model
class LoadDrift(ModuleBase):
    """
    *Deprecated.*   Use ``LoadDriftandInterp`` instead.

    Load drift from a file: an ``.npz`` archive holding ``tIndex`` and
    ``drift`` arrays, such as the one written by ``DriftOutput``.
    """
    input_dummy = Input('input')  # breaks GUI without this???
    load_path = File()
    # this module *writes* drift_raw, so it must be declared as an Output
    output_drift_raw = Output('drift_raw')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        # the context manager closes the underlying NpzFile handle
        with np.load(self.load_path) as data:
            tIndex = data['tIndex']
            drift = data['drift']

        namespace[self.output_drift_raw] = (tIndex, drift)

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, tIndex, drift))
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).

    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only outputs drift as a tuple of time points and drift amount.
    Currently not registered by itself since it is not very useful.
    """
    cache_fft = File("rcc_cache.bin")
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    method = Enum(['RCC', 'MCC', 'DCC'])
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """
        Cross-correlate pairs of Fourier-transformed frames and build the
        linear system relating pairwise shifts to frame-to-frame drifts.

        Parameters
        ----------
        ft_images : array
            Fourier-transformed frames indexed by time along axis 0. The
            debug-file sizing ``(shape[3] - 1) * 2`` suggests a real-FFT
            half spectrum on the last axis (TODO confirm).

        Returns
        -------
        shifts : ndarray, shape (n_pairs, 3)
            Measured shift for each correlated frame pair.
        coefs : ndarray, shape (n_pairs, n_steps - 1)
            Coefficient matrix; row k has ones over the frame-to-frame
            steps spanned by pair k.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                             (ft_images.shape[3] - 1) * 2]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]), key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float was removed in NumPy 1.24; it was an alias of float64
            cc_file_args = (self.debug_cor_file, np.float64, tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0], dtype=cc_file_args[1], mode="w+", shape=cc_file_args[2])
            # materialise as a list: it is indexed by counter below, and a
            # Python 3 zip object is not subscriptable
            cc_args = list(zip(range(shifts.shape[0]), (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift, None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                            time.time() - self._start_time, counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(range(len(autocor_shift_cache)),
                       ft_1_cache,
                       ft_2_cache,
                       autocor_shift_cache,
                       len(ft_1_cache) * ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                       cc_args)

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print("{:.2f} s. Completed calculating {} of {} total shifts.".format(
                        time.time() - self._start_time, i + 1, coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop pairs whose shift came back NaN
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (np.linalg.matrix_rank(coefs) == n_steps - 1), \
            "Something went wrong with coefficient matrix. Not full rank."

        # one row per correlated pair; coefs has n_steps - 1 columns
        return shifts, coefs

    def rcc(self, shift_max, t_shift, shifts, coefs, ):
        """
        Should probably rename function.
        Takes cross correlation results and calculates shifts.

        Solves ``coefs @ drifts = shifts`` by pseudo-inverse. For the 'RCC'
        method it then zeroes the rows whose residual exceeds ``shift_max``,
        largest first, as long as ``coefs`` stays full rank, and re-solves.
        Note that this modifies ``coefs`` in place.

        Returns
        -------
        t_shift : passed through unchanged
        drifts : ndarray
            Frame-to-frame drifts, padded with a zero row for the first time
            point.
        """
        print("{:.2f} s. About to start solving shifts array.".format(time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() - self._start_time))

        if self.method == "RCC":
            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print("Could not remove all residuals over shift_max threshold.")
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # Derived versions of RCC need to override this method.
        # 'execute' of this RCC base class is not thoroughly tested, as its use is probably quite limited.
        # NOTE(review): calc_corr_drift_from_ft_images returns (shifts, coefs)
        # and is given the cache filename here, whereas rcc expects
        # (t_shift, shifts, coefs); confirm before relying on this base path.
        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            proccess_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=proccess_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # always release the worker processes, even if the computation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class Binning(CacheCleanupModule):
    """
    Downsample 3D data (mean).
    X, Y pixels that don't fill a full bin are dropped.
    Pixels in the 3rd dimension can have a partially filled bin.

    Inputs
    ------
    inputName : ImageStack

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    x_start : int
        Starting index in x.
    x_end : int
        Stopping index in x (exclusive); the default -1 keeps every pixel up
        to and including the last one.
    y_start : int
        Starting index in y.
    y_end : int
        Stopping index in y (same convention as ``x_end``).
    binsize : list of 3 ints
        Bin size in x, y and the 3rd dimension.
    cache_bin : File
        Use file as disk cache if provided; otherwise the result is held in
        memory.
    """
    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.inputName]

        # np.int / np.long were removed in NumPy 1.24
        binsize = np.asarray(self.binsize, dtype=int)

        # Slice an arange one longer than the axis, so an end stop of -1
        # still includes the last pixel.
        x_slice = np.arange(ims.data.shape[0] + 1)[slice(self.x_start, self.x_end, 1)]
        y_slice = np.arange(ims.data.shape[1] + 1)[slice(self.y_start, self.y_end, 1)]
        # drop trailing x / y pixels that don't fill a whole bin
        x_slice = x_slice[:x_slice.shape[0] // binsize[0] * binsize[0]]
        y_slice = y_slice[:y_slice.shape[0] // binsize[1] * binsize[1]]

        bincounts = np.asarray([len(x_slice) // binsize[0],
                                len(y_slice) // binsize[1],
                                -(-ims.data.shape[2] // binsize[2])],  # ceil division in z
                               dtype=np.int64)

        x_slice_ind = slice(x_slice[0], x_slice[-1] + 1)
        y_slice_ind = slice(y_slice[0], y_slice[-1] + 1)

        # interleave counts and sizes: (nx, bx, ny, by, nz, bz)
        new_shape = np.stack([bincounts, binsize], -1).flatten()

        dtype = ims.data[:, :, 0].dtype

        out_shape = tuple(int(n) for n in bincounts)
        if self.cache_bin == "":
            binned_image = np.empty(out_shape, dtype=dtype)
        else:
            binned_image = np.memmap(self.cache_bin, dtype=dtype, mode='w+', shape=out_shape)

        # each chunk holds a single (possibly partial) bin in the 3rd dimension
        new_shape_one_chunk = new_shape.copy()
        new_shape_one_chunk[4] = 1
        new_shape_one_chunk[5] = -1

        progress = 0.2 * ims.data.shape[2]
        for i, f in enumerate(np.arange(0, ims.data.shape[2], binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind, f:f + binsize[2]].squeeze()

            binned_image[:, :, i] = raw_data_chunk.reshape(new_shape_one_chunk).mean((1, 3, 5)).squeeze()

            if (f + binsize[2] >= progress):
                if isinstance(binned_image, np.memmap):
                    binned_image.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed binning {} of {} total images.".format(
                    time.time() - self._start_time,
                    min(f + binsize[2], ims.data.shape[2]),
                    ims.data.shape[2]))

        im = ImageStack(binned_image, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename

        try:
            ### Metadata must be logged correctly for the measured drift to be applicable to the source image
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except Exception as e:
            # best effort, but never silent: downstream drift correction relies on these entries
            print("Warning: failed to update binning metadata: {!r}".format(e))

        namespace[self.outputName] = im
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.

    1. Applies median filter for denoising.

    2. Replaces out of range values with defined values, then subtracts
       ``clip_to_lower`` from every pixel.

    3. Applies 2D Tukey filter to dampen potential edge artifacts.

    Inputs
    ------
    input_name : ImageStack

    Outputs
    -------
    output_name : ImageStack

    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``); 0 disables it.
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels below the lower threshold are replaced by this value. This
        value is then subtracted from all pixels.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.windows.tukey``);
        0 disables it.
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype

        # Somewhat arbitrary chunk size: about 1e8 bytes of frames per chunk
        chunk_size = 100000000 / ims.data.shape[0] / ims.data.shape[1] / dtype.itemsize
        chunk_size = int(max(1, chunk_size))

        # scipy.signal.tukey was removed in SciPy 1.13; windows.tukey is the supported name
        tukey_mask_x = signal.windows.tukey(ims.data.shape[0], self.tukey_size)
        tukey_mask_y = signal.windows.tukey(ims.data.shape[1], self.tukey_size)
        self._tukey_mask_2d = np.multiply(*np.meshgrid(tukey_mask_x, tukey_mask_y, indexing='ij'))[:, :, None]

        # np.long was removed in NumPy 1.24
        shape = tuple(int(n) for n in ims.data.shape[:3])
        if self.cache_clip == "":
            raw_data = np.empty(shape, dtype=dtype)
        else:
            raw_data = np.memmap(self.cache_clip, dtype=dtype, mode='w+', shape=shape)

        progress = 0.2 * ims.data.shape[2]
        for f in np.arange(0, ims.data.shape[2], chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(ims.data[:, :, f:f + chunk_size])

            if (f + chunk_size >= progress):
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed clipping {} of {} total images.".format(
                    time.time() - self._start_time,
                    min(f + chunk_size, ims.data.shape[2]),
                    ims.data.shape[2]))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
        Performs the actual filtering here.

        The input array is never modified: filtering works on a private
        floating-point copy. Integer data is promoted to float because the
        in-place subtraction of the float ``clip_to_lower`` is not allowed on
        integer arrays.
        """
        if self.median_filter_size > 0:
            data = ndimage.median_filter(data, self.median_filter_size, mode='nearest')

        data = np.asarray(data)
        data = np.array(data, dtype=np.promote_types(data.dtype, np.float32))

        data[data >= self.threshold_upper] = self.clip_to_upper
        data[data <= self.threshold_lower] = self.clip_to_lower
        data -= self.clip_to_lower

        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d

        return data

    def completeMetadata(self, im):
        """Record the clipping and Tukey parameters in the image metadata."""
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
class ShiftImage(CacheCleanupModule):
    """
    Performs FT based image shift. Only shift in 2D.

    Inputs
    ------
    input_image : ImageStack
        Images with drift.
    input_drift_interpolator :
        Returns drift when called with frame number / time.

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    padding_multipler : Int
        Padding (as multiple of image size) added to the image before shifting to avoid artifacts.
    cache_image : File
        Use file as disk cache if provided.
    """
    input_image = Input('input')
    input_drift_interpolator = Input('drift_interpolator')
    padding_multipler = Int(1)
    cache_image = File("shifted_image.bin")
    outputName = Output('drift_corrected_image')

    def _execute(self, namespace):
        self._start_time = time.time()

        ims = namespace[self.input_image]

        # Times at which to evaluate the drift. If the image was binned in the
        # 3rd dimension ('recipe.binning'), scale to original-frame units and
        # offset by half a bin.
        t_out = np.arange(ims.data.shape[2], dtype=np.float64)  # np.float removed in NumPy 1.24
        if 'recipe.binning' in ims.mdh.keys():
            t_out *= ims.mdh['recipe.binning'][2]
            t_out += 0.5 * ims.mdh['recipe.binning'][2]

        dx = namespace[self.input_drift_interpolator][0](t_out)
        dy = namespace[self.input_drift_interpolator][1](t_out)

        shifted_images = self.shift_images(ims, np.stack([dx, dy], 1), ims.mdh)

        namespace[self.outputName] = ImageStack(shifted_images, titleStub=self.outputName, mdh=ims.mdh)

    def shift_images(self, ims, shifts, mdh):
        """
        Shift each frame of ``ims`` by the matching row of ``shifts`` (x, y in
        real units, converted to pixels with ``mdh.voxelsize``). Each frame is
        zero-padded before the Fourier shift.

        Returns a float64 array (or memmap when ``cache_image`` is set) of
        shape ``ims.data.shape[:3]``.
        """
        import warnings

        # per-axis (before, after) padding, 2d only
        padding = np.stack((ims.data.shape[:2], ) * 2, -1)
        padding *= self.padding_multipler

        padded_image_shape = np.asarray(ims.data.shape[:2], dtype=np.int64) + padding.sum((1))

        dtype = ims.data[:, :, 0].dtype
        padded_image = np.zeros(padded_image_shape, dtype=dtype)

        kx = (np.fft.fftfreq(padded_image_shape[0]))
        ky = (np.fft.fftfreq(padded_image_shape[1]))
        kx, ky = np.meshgrid(kx, ky, indexing='ij')

        images_shape = tuple(int(n) for n in ims.data.shape[:3])

        if self.cache_image == "":
            shifted_images = np.empty(images_shape)
        else:
            shifted_images = np.memmap(self.cache_image, dtype=np.float64, mode='w+', shape=images_shape)

        shifts_in_pixels = np.copy(shifts)
        try:
            shifts_in_pixels[:, 0] = shifts[:, 0] / mdh.voxelsize.x
            shifts_in_pixels[:, 1] = shifts[:, 1] / mdh.voxelsize.y
            if mdh.voxelsize.units == 'um':
                shifts_in_pixels /= 1E3
        except Exception as e:
            # previously `Warning(...)` was constructed but never emitted
            warnings.warn("Failed at converting drift in pixels to real distances: {!r}".format(e))

        x0, y0 = padding[0, 0], padding[1, 0]
        nx, ny = ims.data.shape[0], ims.data.shape[1]
        n_frames = shifted_images.shape[-1]
        for i in range(ims.data.shape[2]):
            padded_image[x0:x0 + nx, y0:y0 + ny] = ims.data[:, :, i].squeeze()
            ft_image = np.fft.fftn(padded_image)

            data_shifted = shift_image_direct_rough(ft_image, shifts_in_pixels[i], kxy=(kx, ky))

            shifted_images[:, :, i] = data_shifted[x0:x0 + nx, y0:y0 + ny]

            if ((i + 1) % max(n_frames // 5, 1) == 0):
                if isinstance(shifted_images, np.memmap):
                    shifted_images.flush()
                print("{:.2f} s. Completed shifting {} of {} total images.".format(
                    time.time() - self._start_time, i + 1, n_frames))

        return shifted_images
class CSVOutputFileBrowse(OutputModule):
    """
    Save tabular data as csv. This module uses a File Browser to set the fileName.

    Parameters
    ----------
    inputName : basestring
        the name (in the recipe namespace) of the table to save.
    fileName : File
        a full path to the file

    Notes
    -----
    We convert the data to a pandas `DataFrame` and use the `to_csv`
    method to save.

    This version of the output module uses the wx.FileDialog.
    """
    inputName = Input('output')
    fileName = File('output.csv')
    saveAs = Button('Save as...')

    def save(self, namespace, context=None):
        """
        Save the recipe's output to CSV.

        Parameters
        ----------
        namespace : dict
            The recipe namespace
        context : dict, optional
            Accepted for interface compatibility. This module writes to the
            browsed ``fileName`` directly, so no pattern substitution is
            done.
        """
        out_filename = self.fileName

        v = namespace[self.inputName]

        if not isinstance(v, pd.DataFrame):
            v = v.toDataFrame()

        v.to_csv(out_filename)

    def _saveAs_changed(self):
        """Handles the user clicking the 'Save as...' button."""
        import wx
        import os

        dirname = os.path.dirname(self.fileName)
        filename = os.path.basename(self.fileName)

        if not dirname:
            dirname = os.getcwd()

        # wx.SAVE / wx.OVERWRITE_PROMPT do not exist in wxPython Phoenix; the
        # FD_* names work in both Classic and Phoenix.
        dlg = wx.FileDialog(None, "Save as...", dirname, filename, "*.csv",
                            wx.FD_SAVE | wx.FD_OVERWRITE_PROMPT)
        try:
            result = dlg.ShowModal()
            inFile = dlg.GetPath()
        finally:
            dlg.Destroy()

        if result == wx.ID_OK:  # Save button was pressed
            self.fileName = inFile

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item, HGroup
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Group(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                          HGroup(
                              Item('saveAs', show_label=False),
                              '_',
                              Item('fileName', style='readonly', springy=True)
                          )
                          ),
                    buttons=['OK'])
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Take a pair of images and calculate the Fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    input_image_a : ImageStack
        First of two images.

    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images.
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In that case, either selects the z
        plane to use (>=0) or performs a maximum projection (<0) for the
        first image.
    image_b_z : int
        Ignored unless flatten_z is True. In that case, either selects the z
        plane to use (>=0) or performs a maximum projection (<0) for the
        second image.
    flatten_z : Bool
        If enabled, ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    """
    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)

    @staticmethod
    def _flatten_z(data, z_index):
        """
        Reduce an (x, y, z) stack to 2D.

        Takes plane ``z_index`` when ``z_index >= 0``; otherwise returns the
        max projection over z. Data that is already 2D (single plane) is
        returned unchanged.
        """
        data = np.asarray(data)
        if data.ndim == 2:
            return data
        if z_index >= 0:
            return data[:, :, z_index]
        return data.max(2)

    def execute(self, namespace):
        self._namespace = namespace
        import multiprocessing

        if self.multiprocessing:
            # two workers, capped at (cores - 1) but never fewer than one
            proccess_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=proccess_count)

        try:
            image_a = namespace[self.input_image_a]

            if len(self.image_b_path.strip()) == 0:
                image_b = image_a
            else:
                image_b = ImageStack(filename=self.image_b_path)

            self._pixel_size_in_nm = np.zeros(3, dtype=np.float64)  # np.float removed in NumPy 1.24
            self._pixel_size_in_nm[0] = image_a.mdh.voxelsize.x
            self._pixel_size_in_nm[1] = image_a.mdh.voxelsize.y
            try:
                self._pixel_size_in_nm[2] = image_a.mdh.voxelsize.z
            except Exception:
                # z voxel size is optional; leave it at 0 when unavailable
                pass

            if image_a.mdh.voxelsize.units == 'um':
                self._pixel_size_in_nm *= 1.E3

            image_a_data = image_a.data[:, :, :, self.c_channel]
            image_b_data = image_b.data[:, :, :, self.c_channel]

            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
                # flatten before any squeeze so single-plane stacks keep their z axis
                image_a_data = self._flatten_z(image_a_data, self.image_a_z)
                image_b_data = self._flatten_z(image_b_data, self.image_b_z)
            else:
                image_a_data = np.asarray(image_a_data).squeeze()
                image_b_data = np.asarray(image_b_data).squeeze()

            image_pair = self.preprocess_images([image_a_data, image_b_data])

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release worker processes even if the computation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)
class CalculateFRCBase(ModuleBase):
    """
    Base class for Fourier ring / shell correlation (FRC/FSC) recipe modules.

    Provides the shared machinery used by derived modules: image
    preprocessing (shape check and optional Tukey apodisation), the
    correlation calculation, curve smoothing, threshold-crossing search,
    plotting and optional saving to disk.

    Derived classes implement ``execute`` and, before calling
    :meth:`calculate_FRC_from_images`, set ``self._namespace``,
    ``self._pixel_size_in_nm`` (one entry per image axis) and, when
    ``multiprocessing`` is enabled, ``self._pool``. Refer to derived classes
    for the user-facing docstrings.
    """
    # Edge apodisation applied in preprocess_images (None disables it).
    pre_filter = Enum(['Tukey_1/8', None])
    # Model used by smooth_frc to smooth the raw correlation curve.
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    # When True, the forward FFTs are dispatched to self._pool.
    multiprocessing = Bool(True)
    # Smoothing factor ``s`` passed to scipy's UnivariateSpline ('Cubic Spline' mode).
    cubic_smoothing = Float(0.01)
    # Optional .npz destination used by save_to_file; an empty value disables saving.
    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, namespace=None):
        """Abstract hook; derived classes must override this."""
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Validate and optionally apodise the images to be compared.

        Parameters
        ----------
        image_pair : sequence of ndarray
            Images to compare; all must have identical shapes.

        Returns
        -------
        sequence of ndarray
            The images multiplied by a Tukey window (alpha = 1/8) when
            ``pre_filter`` is 'Tukey_1/8', otherwise ``image_pair`` unchanged.

        Raises
        ------
        ValueError
            If the image shapes differ or ``pre_filter`` is unrecognised.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        reference_shape = np.shape(image_pair[0])
        if any(np.shape(im) != reference_shape for im in image_pair[1:]):
            raise ValueError("Images not the same dimension.")

        if self.pre_filter == 'Tukey_1/8':
            # Taper intensities towards zero near the image edges.
            return self.filter_images_tukey(image_pair, 1. / 8)
        if self.pre_filter is None:
            return image_pair
        raise ValueError("Unknown pre_filter: {!r}".format(self.pre_filter))

    def filter_images_tukey(self, images, alpha):
        """
        Multiply each image by a separable N-D Tukey window.

        The window is built from the shape of ``images[0]`` (one 1-D Tukey
        window per axis, combined by outer product) and applied to every
        image. New arrays are returned; the inputs are not modified.

        Parameters
        ----------
        images : sequence of ndarray
            Images of identical shape.
        alpha : float
            Tukey taper fraction passed to ``tukey``.

        Returns
        -------
        list of ndarray
        """
        # ``scipy.signal.tukey`` was removed in SciPy 1.13; the ``windows``
        # sub-module has provided it since SciPy 1.1.
        from scipy.signal.windows import tukey

        windows = [tukey(n, alpha=alpha) for n in np.shape(images[0])]
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)
        return [im * window_nd for im in images]

    def calculate_FRC_from_images(self, image_pair, mdh):
        """
        Compute the correlation curve of two images and its threshold crossings.

        Parameters
        ----------
        image_pair : sequence of two ndarray
            Preprocessed images of identical shape.
        mdh : object
            Unused; kept for interface compatibility.

        Returns
        -------
        res : dict
            Reciprocal crossing frequencies keyed by threshold name
            (see calculate_threshold).
        rawdata : dict
            Curves used to derive ``res`` (see calculate_threshold).
        """
        if self.multiprocessing:
            pending = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [job.get() for job in pending]
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        # Radial spatial frequency of every Fourier-space sample, using the
        # per-axis pixel size provided by the derived class.
        shape = image_pair[0].shape
        im_fft_freqs = [np.fft.fftfreq(n, self._pixel_size_in_nm[i]) for i, n in enumerate(shape)]
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        ft_a, ft_b = ft_images[0], ft_images[1]
        im1_fft_power = ft_a * np.conj(ft_a)
        im2_fft_power = ft_b * np.conj(ft_b)
        im12_fft_power = ft_a * np.conj(ft_b)

        # Best-effort diagnostic output: a failure here must not abort the FRC.
        try:
            spectra = [np.atleast_3d(np.fft.fftshift(p))
                       for p in (im1_fft_power, im2_fft_power, im12_fft_power)]
            self._namespace[self.output_fft_images_cc] = ImageStack(data=np.stack(spectra, 3),
                                                                    titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            print(e)

        radii = im_R.flatten()
        im1_fft_flat_res = CalculateFRCBase.BinData(radii, im1_fft_power.flatten(), statistic='mean', bins=201)
        im2_fft_flat_res = CalculateFRCBase.BinData(radii, im2_fft_power.flatten(), statistic='mean', bins=201)
        im12_fft_flat_res = CalculateFRCBase.BinData(radii, im12_fft_power.flatten(), statistic='mean', bins=201)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic * im2_fft_flat_res.statistic))

        # Curves are indexed by the left edge of each radial bin.
        freq = im12_fft_flat_res.bin_edges[:-1]
        smoothed_frc = self.smooth_frc(freq, corr, self.cubic_smoothing)

        return self.calculate_threshold(freq, corr, smoothed_frc, im12_fft_flat_res.counts)

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """
        Build a callable approximating ``corr`` as a function of ``freq``.

        The returned object is called as ``f(x=...)``. The model depends on
        ``frc_smoothing_func``:

        * None: piecewise-constant ``interp1d(kind='next')`` interpolation.
        * 'Sigmoid': :meth:`Sigmoid` fitted by Nelder-Mead least squares.
        * 'Cubic Spline': ``UnivariateSpline`` (k=3) with smoothing factor
          ``cubic_smoothing``.

        Raises
        ------
        ValueError
            If ``frc_smoothing_func`` is not one of the above.
        """
        if self.frc_smoothing_func is None:
            return interpolate.interp1d(freq, corr, kind='next')

        if self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x) - corr)),
                                        # integer index: ``len(freq) / 2`` is a float on Python 3
                                        [1, freq[len(freq) // 2]],
                                        args=(freq,),
                                        method='Nelder-Mead')
            return partial(func, n=fit_res.x[0], c=fit_res.x[1])

        if self.frc_smoothing_func == "Cubic Spline":
            return interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)

        raise ValueError("Unknown frc_smoothing_func: {!r}".format(self.frc_smoothing_func))

    @staticmethod
    def _find_threshold_crossing(freq, corr_func, threshold_func):
        """
        Locate the frequency where ``corr_func`` meets ``threshold_func``.

        Minimises the squared difference with Nelder-Mead, starting from the
        first sampled frequency at which the curve lies below the threshold
        (``freq[0]`` if it never does). Returns the scipy ``OptimizeResult``;
        the crossing frequency is ``result.x[0]``.
        """
        below = corr_func(x=freq) - threshold_func(freq) < 0
        start = freq[np.argmax(below)]
        return optimize.minimize(lambda x: np.square(corr_func(x=x) - threshold_func(x)),
                                 start, method='Nelder-Mead')

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Find where the smoothed curve crosses the standard thresholds and plot it.

        Parameters
        ----------
        freq : ndarray
            Frequency of each bin (left bin edge).
        corr : ndarray
            Raw correlation per bin.
        corr_func : callable
            Smoothed curve from smooth_frc, called as ``corr_func(x=...)``.
        counts : ndarray
            Number of Fourier-space samples in each bin.

        Returns
        -------
        res : dict
            Keys 'frc 1/7', 'frc 3 sigma' and 'frc half bit', each the
            reciprocal of the corresponding crossing frequency.
        rawdata : dict
            Raw, smoothed and threshold curves evaluated at ``freq``.

        Notes
        -----
        Side effect: draws a figure and stores ``Plot(plot)`` in
        ``self._namespace`` under ``output_frc_plot``.
        """
        res = {}

        fsc_0143 = self._find_threshold_crossing(freq, corr_func, lambda f: 0.143)
        res['frc 1/7'] = 1. / fsc_0143.x[0]

        # Noise level per bin derived from its sample count.
        sigma = 1.0 / np.sqrt(counts * 0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = self._find_threshold_crossing(freq, corr_func, lambda f: 3. * sigma_spl(f))
        res['frc 3 sigma'] = 1. / fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        sqrt_counts = np.sqrt(counts)
        half_bit = (0.2071 + 1.9102 / sqrt_counts) / (1.2071 + 0.9102 / sqrt_counts)
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = self._find_threshold_crossing(freq, corr_func, half_bit_spl)
        res['frc half bit'] = 1. / fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1, 2, figsize=(10, 4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7: {:.2f} nm".format(1. / fsc_0143.x[0])

            axes[0].plot(freq, 3 * sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma: {:.2f} nm".format(1. / fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')
            frc_text += "\n1/2 bit: {:.2f} nm".format(1. / fsc_half_bit.x[0])

            axes[0].legend()

            # Label the frequency axis with the equivalent resolution (1 / freq);
            # a tick at 0 would otherwise emit a divide-by-zero warning.
            x_ticklocs = axes[0].get_xticks()
            with np.errstate(divide='ignore'):
                axes[0].set_xticklabels(["{:.1f}".format(1. / loc) for loc in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center',
                         verticalalignment='center', transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()
        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq': freq,
                   'corr': corr,
                   'smooth': corr_func(x=freq),
                   '1/7': np.ones_like(freq) / 7,
                   '3 sigma': 3 * sigma_spl(freq),
                   '1/2 bit': half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """
        Save the raw curves and threshold results to ``save_path`` if it is set.

        Both entries are dicts, which ``np.savez_compressed`` stores as 0-d
        object arrays, so reading them back needs ``np.load(..., allow_pickle=True)``.
        """
        # Truthiness test instead of ``is not ""`` (identity with a literal is unreliable).
        if not self.save_path:
            return
        np.savez_compressed(self.save_path,
                            raw=namespace[self.output_frc_raw],
                            results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Bin ``data`` by ``indexes`` into equal-width bins and reduce each bin.

        Unlike ``scipy.stats.binned_statistic`` this supports complex ``data``.

        Parameters
        ----------
        indexes : ndarray
            Values used to assign bins (flattened).
        data : ndarray
            Values to aggregate; same number of elements as ``indexes``.
        statistic : {'mean', 'sum'}
            Reduction applied to each bin.
        bins : int
            Number of bin *edges*, evenly spaced from ``indexes.min()`` to
            ``indexes.max()``; this yields ``bins - 1`` bins.

        Returns
        -------
        Result
            Object with ``statistic`` (per-bin value, dtype of ``data``),
            ``bin_edges`` and ``counts``. Bins are half-open ``[lo, hi)``
            except the last, which also includes ``indexes.max()`` (same
            convention as ``numpy.histogram``). An empty bin gets whatever the
            reducer returns for an empty array (NaN for 'mean', 0 for 'sum').

        Raises
        ------
        ValueError
            If ``statistic`` is not supported.
        """
        reducers = {'mean': np.mean, 'sum': np.sum}
        if statistic not in reducers:
            raise ValueError("Unsupported statistic: {!r}".format(statistic))
        func = reducers[statistic]

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        flat_index = np.asarray(indexes).flatten()
        flat_data = np.asarray(data).flatten()
        edges = np.linspace(flat_index.min(), flat_index.max(), bins)
        binned = np.zeros(len(edges) - 1, dtype=flat_data.dtype)

        order = np.argsort(flat_index)
        index_sorted = flat_index[order]
        data_sorted = flat_data[order]

        # Start position of each bin in the sorted arrays.
        starts = np.searchsorted(index_sorted, edges)
        # side='left' would stop the last bin before values equal to the top
        # edge (indexes.max()), silently dropping them; close the last bin.
        starts[-1] = len(index_sorted)

        for i in range(len(edges) - 1):
            binned[i] = func(data_sorted[starts[i]:starts[i + 1]])

        res = Result()
        res.statistic = binned
        res.bin_edges = edges
        res.counts = np.diff(starts).astype(int)  # builtin int: ``np.int`` was removed in NumPy 1.24
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """
        Logistic curve ``1 - 1 / (1 + exp(n * (c - x)))``.

        Equals 0.5 at ``x == c``; for ``n > 0`` it falls from 1 towards 0 as
        ``x`` increases.
        """
        return 1 - 1 / (1 + np.exp(n * (-x + c)))