Example #1
0
class DriftOutput(ModuleBase):
    """
    Save drift data to a file.

    Inputs
    ------
    input_name : tuple
        Drift measured from localization or image dataset, as a
        ``(tIndex, drift)`` pair.

    Outputs
    -------
    output_dummy : None
        Blank output. Required to run correctly.

    Parameters
    ----------
    save_path : File
        Filepath to save drift data (written with ``numpy.savez_compressed``
        using the keys ``tIndex`` and ``drift``).
    """

    input_name = Input('drift')
    save_path = File('drift')
    output_dummy = Output('dummy')  # will not run execute without this

    def execute(self, namespace, context=None):
        """Write the drift pair found in ``namespace`` to ``save_path``.

        Parameters
        ----------
        namespace : dict
            Recipe namespace; ``namespace[self.input_name]`` must unpack into
            ``(tIndex, drift)``.
        context : dict, optional
            Accepted for interface compatibility but currently unused: the
            output filename is taken verbatim from ``save_path`` (no pattern
            substitution is performed).
        """
        out_filename = self.save_path

        tIndex, drift = namespace[self.input_name]

        np.savez_compressed(out_filename, tIndex=tIndex, drift=drift)
        print('saved')
Example #2
0
class FileSource(PointSource):
    """Point source whose points are read from a file on disk."""

    file = File()

    # name = Str('Points File')

    def getPoints(self):
        """Return whatever ``numpy.load`` reads from ``self.file``."""
        from numpy import load as _load_points
        points = _load_points(self.file)
        return points
Example #3
0
class FileSource(PointSource):
    """Point source backed by a file of points loaded with ``numpy.load``."""

    file = File()

    # name = Str('Points File')

    def getPoints(self):
        """Return the points read from ``self.file``."""
        import numpy
        return numpy.load(self.file)

    def genMetaData(self, mdh):
        """Record the point-source type and file name in the metadata handler ``mdh``."""
        source_entries = (
            ('GeneratedPoints.Source.Type', 'File'),
            ('GeneratedPoints.Source.FileName', self.file),
        )
        for key, value in source_entries:
            mdh[key] = value
class CombineBeadStacks(ModuleBase):
    """
    Combine multiple bead stacks along the 4th dimension.

    X, Y and Z must be identical for all stacks; their 4th-dimension extents
    are concatenated in the order given by ``files``.

    Parameters
    ----------
    files : list of File
        Paths of the bead stacks to combine (the trait's minimum length is 2).
    cache : File
        If non-empty, the combined array is backed by a ``numpy.memmap`` at
        this path instead of being held in memory.
    """

    inputName = Input('dummy')

    files = List(File, ['', ''], 2)
    cache = File()

    outputName = Output('bead_images')

    def execute(self, namespace):
        """Load every stack in ``files``, concatenate them and store the result.

        The combined ``ImageStack`` is written to ``namespace[self.outputName]``.
        Its metadata is derived from the first stack, with the source
        filenames recorded under ``PSFExtraction.SourceFilenames``.
        """
        # Shape, dtype and metadata come from the first stack.
        ims = ImageStack(filename=self.files[0])
        # ``np.long`` is unavailable in NumPy 1.24-1.26; it aliased builtin ``int``.
        dims = np.asarray(ims.data.shape, dtype=int)
        dims[3] = 0
        dtype_ = ims.data[:, 0, 0, 0].dtype
        mdh = ims.mdh
        del ims

        # First pass: total 4th-dimension length across all stacks.
        for fil in self.files:
            ims = ImageStack(filename=fil)
            dims[3] += ims.data.shape[3]
            del ims

        if self.cache != '':
            raw_data = np.memmap(self.cache, dtype=dtype_, mode='w+', shape=tuple(dims))
        else:
            raw_data = np.zeros(shape=tuple(dims), dtype=dtype_)

        # Second pass: copy each stack into its slot along the 4th dimension.
        counter = 0
        for fil in self.files:
            ims = ImageStack(filename=fil)
            c_len = ims.data.shape[3]
            data = ims.data[:, :, :, :]
            # Append trailing singleton axes so the data is always 4D.
            data.shape += (1,) * (4 - data.ndim)
            raw_data[:, :, :, counter:counter + c_len] = data
            counter += c_len
            del ims

        if isinstance(raw_data, np.memmap):
            raw_data.flush()

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(mdh)
            new_mdh["PSFExtraction.SourceFilenames"] = self.files
        except Exception as e:
            # Metadata is best-effort; the combined image is produced regardless.
            print(e)

        namespace[self.outputName] = ImageStack(data=raw_data, mdh=new_mdh)
class svmSegment(Filter):
    """Segment images with a previously trained SVM classifier.

    Parameters
    ----------
    classifier : File
        Path of the saved classifier, loaded via
        ``PYME.Analysis.svmSegment.svmClassifier``.
    """
    classifier = File('')

    def _loadClassifier(self):
        """Load the classifier, reloading it whenever ``self.classifier`` changes.

        The loaded classifier is cached on the instance together with the
        filename it came from, so editing the ``classifier`` parameter after a
        first run takes effect instead of silently reusing the old model.
        """
        from PYME.Analysis import svmSegment
        if getattr(self, '_cf_filename', None) != self.classifier:
            # Assign the classifier before recording the filename so a failed
            # load is retried on the next call.
            self._cf = svmSegment.svmClassifier(filename=self.classifier)
            self._cf_filename = self.classifier

    def applyFilter(self, data, chanNum, frNum, im):
        """Classify ``data`` (cast to float32) with the loaded classifier."""
        self._loadClassifier()

        return self._cf.classify(data.astype('f'))

    def completeMetadata(self, im):
        """Record the classifier path in the output image metadata."""
        im.mdh['SVMSegment.classifier'] = self.classifier
Example #6
0
class CNNFilter(Filter):
    """ Use a previously trained Keras neural network to filter the data. Used for learnt de-noising and/or
    deconvolution. Runs prediction piecewise over the image with over-lapping ROIs
    and averages the prediction results.

    Parameters
    ----------
    model : File
        Path of a saved Keras model (loaded with ``keras.models.load_model``).
        Its ``input_shape`` is unpacked as ``(_, size_x, size_y, _)``.
    step_size : int
        Stride, in pixels, between neighbouring ROIs.

    Notes
    -----

    Keras and either Tensorflow or Theano must be installed and set up for this module to work. This is not a default
    dependency of python-microscopy as the conda-installable versions don't have GPU support.

    """
    model = File('')
    step_size = Int(14)

    def _load_model(self):
        """Load the Keras model, reloading only when ``self.model`` changes."""
        from keras.models import load_model
        if getattr(self, '_model_name', None) != self.model:
            self._model_name = self.model
            self._model = load_model(self._model_name)  # TODO - make cluster-aware

    @staticmethod
    def _window_starts(length, kernel_size, step):
        """Start offsets of ``kernel_size`` ROIs tiling ``length`` pixels with stride ``step``.

        The last ROI is aligned to the far edge so the whole axis is covered.
        Returns an empty list if the axis is shorter than the kernel.
        """
        if length < kernel_size:
            return []
        starts = list(range(0, length - kernel_size + 1, step))
        if starts[-1] != length - kernel_size:
            starts.append(length - kernel_size)
        return starts

    def applyFilter(self, data, chanNum, frNum, im):
        """Predict over overlapping ROIs of ``data`` and average the overlaps.

        ``data`` is indexed as a 2D frame. Each ROI is reshaped to
        ``(1, size_x, size_y, 1)`` before ``model.predict``. Each output pixel
        is the mean of all predictions covering it. Pixels covered by no ROI
        (only possible when the frame is smaller than the kernel) stay 0.

        Returns
        -------
        numpy.ndarray
            float32 array with the same shape as ``data``.
        """
        self._load_model()

        out = np.zeros(data.shape, 'f')
        # Number of ROIs contributing to each pixel; used to average overlaps.
        weight = np.zeros(data.shape, 'f')

        _, kernel_size_x, kernel_size_y, _ = self._model.input_shape

        step = max(int(self.step_size), 1)
        for i_x in self._window_starts(out.shape[0], kernel_size_x, step):
            for i_y in self._window_starts(out.shape[1], kernel_size_y, step):
                roi = (slice(i_x, i_x + kernel_size_x),
                       slice(i_y, i_y + kernel_size_y))
                d_i = data[roi].reshape(1, kernel_size_x, kernel_size_y, 1)
                p = self._model.predict(d_i).reshape(kernel_size_x,
                                                     kernel_size_y)
                out[roi] += p
                weight[roi] += 1

        # True average of overlapping predictions; uncovered pixels remain 0.
        np.divide(out, weight, out=out, where=weight > 0)
        return out

    def completeMetadata(self, im):
        """Record the model path in the output image metadata."""
        im.mdh['CNNFilter.model'] = self.model
Example #7
0
class LoadDrift(ModuleBase):
    """
    *Deprecated.*   Use ``LoadDriftandInterp`` instead.

    Load drift from a file.

    ``load_path`` must point to an archive readable by ``numpy.load`` that
    contains ``tIndex`` and ``drift`` arrays.
    """

    input_dummy = Input('input')  # breaks GUI without this???
    load_path = File()
    # Written by ``execute``, so it is declared as an Output.
    output_drift_raw = Output('drift_raw')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        """Load drift from ``load_path`` into ``namespace``.

        Stores ``(tIndex, drift)`` under ``output_drift_raw`` and a drift
        ``Plot`` under ``output_drift_plot``.
        """
        data = np.load(self.load_path)
        try:
            # Indexing an NpzFile reads the array into memory, so the values
            # stay valid after the archive is closed.
            tIndex = data['tIndex']
            drift = data['drift']
        finally:
            # NpzFile keeps its file handle open until closed; plain arrays
            # have no close() and need nothing.
            close = getattr(data, 'close', None)
            if close is not None:
                close()
        namespace[self.output_drift_raw] = (tIndex, drift)

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(
            partial(generate_drift_plot, tIndex, drift))
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.

    Parameters
    ----------
    cache_fft : File
        File holding cached Fourier-transformed images. When multiprocessing
        and non-empty, frame indices (not arrays) are sent to the workers.
    method : {'RCC', 'MCC', 'DCC'}
        Redundant, mean or direct cross-correlation.
    shift_max : float
        Residual threshold (nm) above which RCC discards a correlation pair.
    corr_window : int
        Maximum frame separation of correlated pairs (ignored for DCC);
        0 or less correlates every pair.
    multiprocessing : bool
        Compute cross-correlations in a process pool.
    debug_cor_file : File
        If non-empty, cross-correlation images are written to this
        memory-mapped file for debugging.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-corelation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """Cross-correlate pairs of Fourier-transformed frames.

        Parameters
        ----------
        ft_images : numpy.ndarray
            Fourier-transformed images indexed by time step along axis 0
            (presumably 4D, given how the debug file shape is derived from
            axes 1-3 — TODO confirm).

        Returns
        -------
        shifts : numpy.ndarray
            ``(n_pairs, 3)`` measured shift per correlated pair; rows
            containing NaN are dropped.
        coefs : numpy.ndarray
            ``(n_pairs, n_steps - 1)`` coefficient matrix; the row for pair
            ``(i, j)`` has ones on steps ``i .. j-1``.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            # A window wider than the series adds no pairs; without clamping
            # the formula under-counts (down to 0 or below) and the fill
            # loop below would overrun ``coefs``.
            window = min(self.corr_window, n_steps - 1)
            coefs_size = n_steps * window - window * (window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if self.debug_cor_file != "":
            cc_file_shape = [
                shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                (ft_images.shape[3] - 1) * 2
            ]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]),
                          key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # ``np.float`` was removed in NumPy 1.24; it aliased float64.
            cc_file_args = (self.debug_cor_file, np.float64,
                            tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0],
                                dtype=cc_file_args[1],
                                mode="w+",
                                shape=cc_file_args[2])
            # Materialised as a list: the serial branch indexes it with
            # ``cc_args[counter]``, which a Python 3 ``zip`` object rejects.
            cc_args = list(zip(range(shifts.shape[0]),
                               (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (
                        j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                # Pair (i, j) spans frame-to-frame steps i .. j-1.
                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if self.cache_fft != "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift,
                                                    None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print(
                            "{:.2f} s. Completed calculating {} of {} total shifts."
                            .format(time.time() - self._start_time,
                                    counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(
                range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                autocor_shift_cache,
                len(ft_1_cache) *
                ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                cc_args)
            # Results arrive out of order; each carries its row index j.
            for i, (j, res) in enumerate(
                    self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print(
                        "{:.2f} s. Completed calculating {} of {} total shifts."
                        .format(time.time() - self._start_time, i + 1,
                                coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(
            time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if self.debug_cor_file != "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(_cc_image=ImageStack(data=cc_file.copy()))
            del cc_file
        else:
            self.trait_setq(_cc_image=None)

        assert (np.all(np.any(
            coefs, axis=1))), "Coefficient matrix filled less than expected."

        # Drop pairs whose shift could not be measured (NaN).
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".
                  format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (
            np.linalg.matrix_rank(coefs) == n_steps -
            1), "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs

    def rcc(
        self,
        shift_max,
        t_shift,
        shifts,
        coefs,
    ):
        """
        Solve pairwise shifts for per-step drift (should probably be renamed).

        Parameters
        ----------
        shift_max : float
            For ``method == 'RCC'``, pairs whose residual norm exceeds this
            are discarded, largest first, while ``coefs`` stays full rank.
        t_shift :
            Time points; returned unchanged.
        shifts : numpy.ndarray
            ``(n_pairs, 3)`` measured pairwise shifts.
        coefs : numpy.ndarray
            ``(n_pairs, n_steps - 1)`` coefficient matrix. Modified in place
            when RCC discards rows.

        Returns
        -------
        t_shift
            As passed in.
        drifts : numpy.ndarray
            ``(n_steps, 3)`` frame-to-frame drift with a zero row prepended
            for the first time point.
        """

        print("{:.2f} s. About to start solving shifts array.".format(
            time.time() - self._start_time))

        # Least-squares estimate of per-step drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() -
                                                            self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Indices of residuals above shift_max, largest first
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[
                residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print(
                        "Could not remove all residuals over shift_max threshold."
                    )
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".
                  format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]],
                        'constant',
                        constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        """Run RCC on the cached FFTs and store cumulative drift in ``namespace``.

        Derived versions of RCC need to override this method; this base
        implementation is not thoroughly tested.
        """
        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            # NOTE(review): calc_corr_drift_from_ft_images expects an array but
            # is given the cache filename, and its (shifts, coefs) result lacks
            # the t_shift argument rcc() needs. This base path looks unusable
            # as written — confirm before relying on it.
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # Always release worker processes, even if correlation fails.
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
class Binning(CacheCleanupModule):
    """
    Downsample 3D data (mean).
    X, Y pixels that don't fill a full bin are dropped.
    Pixels in the 3rd dimension can have a partially filled bin.

    Inputs
    ------
    inputName : ImageStack

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    x_start : int
        Starting index in x.
    x_end : int
        Stopping index in x. Non-negative values are exclusive; negative values
        are shifted by one, so the default -1 keeps every pixel through the last.
    y_start : int
        Starting index in y.
    y_end : int
        Stopping index in y (same convention as ``x_end``).
    binsize : list of int
        Bin size in x, y and the 3rd dimension.
    cache_bin : File
        Use file as disk cache if provided; otherwise bin in memory.
    """

    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    #    z_start = Int(0)
    #    z_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        """Bin ``namespace[self.inputName]`` and store the result under ``outputName``."""
        import warnings

        self._start_time = time.time()
        ims = namespace[self.inputName]

        # ``np.int`` / ``np.long`` are unavailable in recent NumPy; both aliased ``int``.
        binsize = np.asarray(self.binsize, dtype=int)

        # unconventional, end stop in inclusive: indexing arange(N + 1) makes
        # the default end of -1 select every pixel through the last one.
        x_slice = np.arange(ims.data.shape[0] + 1)[slice(
            self.x_start, self.x_end, 1)]
        y_slice = np.arange(ims.data.shape[1] + 1)[slice(
            self.y_start, self.y_end, 1)]
        # Drop trailing x/y pixels that do not fill a whole bin.
        x_slice = x_slice[:x_slice.shape[0] // binsize[0] * binsize[0]]
        y_slice = y_slice[:y_slice.shape[0] // binsize[1] * binsize[1]]
        if len(x_slice) == 0 or len(y_slice) == 0:
            raise ValueError(
                'Selected x/y region is smaller than one bin (binsize=%s).' %
                (list(self.binsize), ))

        bincounts = np.asarray([
            len(x_slice) // binsize[0],
            len(y_slice) // binsize[1],
            -(-ims.data.shape[2] // binsize[2]),  # ceil: keeps a partial last bin
        ], dtype=int)

        x_slice_ind = slice(x_slice[0], x_slice[-1] + 1)
        y_slice_ind = slice(y_slice[0], y_slice[-1] + 1)

        # Interleaved (count, size) per axis: (nx, bx, ny, by, nz, bz).
        new_shape = np.stack([bincounts, binsize], -1).flatten()

        # need to wrap this to work for multiple color channel images
        dtype = ims.data[:, :, 0].dtype

        out_shape = tuple(int(n) for n in bincounts)
        if self.cache_bin == "":
            binned_image = np.empty(out_shape, dtype=dtype)
        else:
            binned_image = np.memmap(self.cache_bin,
                                     dtype=dtype,
                                     mode='w+',
                                     shape=out_shape)

        # Bin one z-chunk at a time: the z axis becomes a single (possibly
        # partial) bin whose size is inferred from the chunk.
        new_shape_one_chunk = new_shape.copy()
        new_shape_one_chunk[4] = 1
        new_shape_one_chunk[5] = -1
        progress = 0.2 * ims.data.shape[2]
        for i, f in enumerate(np.arange(0, ims.data.shape[2], binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind,
                                      f:f + binsize[2]].squeeze()

            binned_image[:, :, i] = raw_data_chunk.reshape(
                new_shape_one_chunk).mean((1, 3, 5)).squeeze()

            if (f + binsize[2] >= progress):
                if isinstance(binned_image, np.memmap):
                    binned_image.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed binning {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + binsize[2], ims.data.shape[2]),
                             ims.data.shape[2]))

        im = ImageStack(binned_image, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename
        try:
            ### Metadata must be logged correctly for the measured drift to be applicable to the source image
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            #            im.mdh['voxelsize.z'] *= binsize[2]
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except Exception as e:
            # Still best-effort, but no longer silent: downstream drift
            # correction depends on these entries.
            warnings.warn('Binning: failed to update metadata (%r)' % (e, ))

        namespace[self.outputName] = im
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.

    1. Applies median filter for denoising.

    2. Replaces out of range values with defined values.

    3. Applies 2D Tukey filter to dampen potential edge artifacts.

    Inputs
    ------
    input_name : ImageStack

    Outputs
    -------
    output_name : ImageStack

    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``).
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels below the lower threshold are replaced by this value.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.windows.tukey``).
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        """Filter ``namespace[self.input_name]`` chunk by chunk into ``output_name``."""
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype

        # Somewhat arbitrary way to decide on chunk size (~1e8 bytes per chunk)
        chunk_size = 100000000 / ims.data.shape[0] / ims.data.shape[
            1] / dtype.itemsize
        chunk_size = int(max(1, chunk_size))

        # ``scipy.signal.tukey`` was removed in SciPy 1.13; the window lives
        # in ``scipy.signal.windows``.
        tukey_mask_x = signal.windows.tukey(ims.data.shape[0], self.tukey_size)
        tukey_mask_y = signal.windows.tukey(ims.data.shape[1], self.tukey_size)
        # (X, Y, 1) so it broadcasts over the frames of each chunk.
        self._tukey_mask_2d = np.multiply(
            *np.meshgrid(tukey_mask_x, tukey_mask_y, indexing='ij'))[:, :,
                                                                     None]

        # ``np.long`` is unavailable in recent NumPy; it aliased ``int``.
        out_shape = tuple(int(n) for n in ims.data.shape[:3])
        if self.cache_clip == "":
            raw_data = np.empty(out_shape, dtype=dtype)
        else:
            raw_data = np.memmap(self.cache_clip,
                                 dtype=dtype,
                                 mode='w+',
                                 shape=out_shape)

        progress = 0.2 * ims.data.shape[2]
        for f in np.arange(0, ims.data.shape[2], chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(
                ims.data[:, :, f:f + chunk_size])

            if (f + chunk_size >= progress):
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += 0.2 * ims.data.shape[2]
                print("{:.2f} s. Completed clipping {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + chunk_size, ims.data.shape[2]),
                             ims.data.shape[2]))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
        Performs the actual filtering here.

        Parameters
        ----------
        data : array-like
            Chunk of frames indexed ``[x, y, frame]``.

        Returns
        -------
        numpy.ndarray
            Filtered float64 array; ``data`` itself is never modified.
        """
        # Work on a float copy. This avoids mutating the caller's image when
        # the median filter is disabled, and allows subtracting a float
        # offset from integer (e.g. uint16) data.
        data = np.array(data, dtype=np.float64)
        if self.median_filter_size > 0:
            # NOTE: a scalar size filters along the frame axis too, so frames
            # at chunk edges use 'nearest' padding rather than neighbours.
            data = ndimage.median_filter(data,
                                         self.median_filter_size,
                                         mode='nearest')
        # Evaluate both masks before replacing anything, so a replacement
        # value cannot be re-clipped by the other threshold.
        upper_mask = data >= self.threshold_upper
        lower_mask = data <= self.threshold_lower
        data[upper_mask] = self.clip_to_upper
        data[lower_mask] = self.clip_to_lower
        data -= self.clip_to_lower
        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d
        return data

    def completeMetadata(self, im):
        """Record the clipping and Tukey parameters in the output metadata."""
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
class ShiftImage(CacheCleanupModule):
    """
    Performs FT based image shift. Only shift in 2D.

    Inputs
    ------
    input_image : ImageStack
        Images with drift.
    input_drift_interpolator :
        Pair of callables ``(fx, fy)`` returning x and y drift when called
        with frame number / time.

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    padding_multipler : Int
        Padding (as multiple of image size) added on each side of the image before shifting to avoid artifacts.
    cache_image : File
        Use file as disk cache if provided.
    """

    input_image = Input('input')
    input_drift_interpolator = Input('drift_interpolator')
    padding_multipler = Int(1)

    cache_image = File("shifted_image.bin")
    outputName = Output('drift_corrected_image')

    def _execute(self, namespace):
        """Evaluate drift for every frame and store the shifted stack under ``outputName``."""
        self._start_time = time.time()

        ims = namespace[self.input_image]

        # ``np.float`` was removed in NumPy 1.24; it aliased float64.
        t_out = np.arange(ims.data.shape[2], dtype=np.float64)

        # Temporally binned input (see 'recipe.binning'): scale frame indices
        # by the 3rd-axis bin size and offset by half a bin.
        if 'recipe.binning' in ims.mdh.keys():
            t_out *= ims.mdh['recipe.binning'][2]
            t_out += 0.5 * ims.mdh['recipe.binning'][2]

        drift_interpolator = namespace[self.input_drift_interpolator]
        dx = drift_interpolator[0](t_out)
        dy = drift_interpolator[1](t_out)

        shifted_images = self.shift_images(ims, np.stack([dx, dy], 1), ims.mdh)

        namespace[self.outputName] = ImageStack(shifted_images,
                                                titleStub=self.outputName,
                                                mdh=ims.mdh)

    def shift_images(self, ims, shifts, mdh):
        """Shift each frame of ``ims`` in x/y by the matching row of ``shifts``.

        Parameters
        ----------
        ims : ImageStack
            Stack indexed ``[x, y, frame]``.
        shifts : numpy.ndarray
            ``(n_frames, 2)`` drift per frame in real units (presumably nm,
            given the um -> /1e3 conversion below — TODO confirm).
        mdh :
            Metadata providing ``voxelsize.x``, ``voxelsize.y`` and
            ``voxelsize.units``.

        Returns
        -------
        numpy.ndarray or numpy.memmap
            float64 stack with the same ``[x, y, frame]`` shape as ``ims``.
        """
        import warnings

        # [[x_before, x_after], [y_before, y_after]] -- 2d only
        padding = np.stack((ims.data.shape[:2], ) * 2, -1)
        padding *= self.padding_multipler

        padded_image_shape = np.asarray(ims.data.shape[:2],
                                        dtype=int) + padding.sum((1))

        dtype = ims.data[:, :, 0].dtype
        padded_image = np.zeros(padded_image_shape, dtype=dtype)

        kx = np.fft.fftfreq(padded_image_shape[0])
        ky = np.fft.fftfreq(padded_image_shape[1])
        kx, ky = np.meshgrid(kx, ky, indexing='ij')

        # ``np.long`` is unavailable in recent NumPy; it aliased ``int``.
        images_shape = tuple(int(n) for n in ims.data.shape[:3])

        if self.cache_image == "":
            shifted_images = np.empty(images_shape)
        else:
            shifted_images = np.memmap(self.cache_image,
                                       dtype=np.float64,
                                       mode='w+',
                                       shape=images_shape)

        shifts_in_pixels = np.copy(shifts)

        try:
            shifts_in_pixels[:, 0] = shifts[:, 0] / mdh.voxelsize.x
            shifts_in_pixels[:, 1] = shifts[:, 1] / mdh.voxelsize.y

            if mdh.voxelsize.units == 'um':
                shifts_in_pixels /= 1E3
        except Exception as e:
            # Previously a bare ``Warning(...)`` was constructed and discarded,
            # so this failure was never reported.
            warnings.warn(
                "Failed to convert drift from real distances to pixels: %r" %
                (e, ))

        x0, y0 = padding[0, 0], padding[1, 0]
        nx, ny = ims.data.shape[0], ims.data.shape[1]
        n_frames = ims.data.shape[2]
        for i in np.arange(n_frames):
            padded_image[x0:x0 + nx, y0:y0 + ny] = ims.data[:, :, i].squeeze()

            ft_image = np.fft.fftn(padded_image)

            data_shifted = shift_image_direct_rough(ft_image,
                                                    shifts_in_pixels[i],
                                                    kxy=(kx, ky))

            # Crop the padding back off.
            shifted_images[:, :, i] = data_shifted[x0:x0 + nx, y0:y0 + ny]

            if ((i + 1) % max(shifted_images.shape[-1] // 5, 1) == 0):
                if isinstance(shifted_images, np.memmap):
                    shifted_images.flush()
                print("{:.2f} s. Completed shifting {} of {} total images.".
                      format(time.time() - self._start_time, i + 1,
                             shifted_images.shape[-1]))

        return shifted_images
Example #12
0
class CSVOutputFileBrowse(OutputModule):
    """
    Save tabular data as csv. This module uses a File Browser to set the fileName

    Parameters
    ----------

    inputName : basestring
        the name (in the recipe namespace) of the table to save.

    fileName : File
        a full path to the file

    Notes
    -----

    We convert the data to a pandas `DataFrame` and use the `to_csv`
    method to save.

    This version of the output module uses the wx.FileDialog.

    """

    inputName = Input('output')
    fileName = File('output.csv')
    saveAs = Button('Save as...')

    def save(self, namespace, context=None):
        """
        Save recipes output(s) to CSV

        Parameters
        ----------
        namespace : dict
            The recipe namespace
        context : dict, optional
            Information about the source file to allow pattern substitution to generate the output name. At least
            'basedir' (which is the fully resolved directory name in which the input file resides) and
            'filestub' (which is the filename without any extension) should be resolved.
            Currently unused: the output path is ``fileName`` verbatim.

        Returns
        -------
        None

        """

        out_filename = self.fileName
        v = namespace[self.inputName]

        if not isinstance(v, pd.DataFrame):
            v = v.toDataFrame()

        v.to_csv(out_filename)

    def _saveAs_changed(self):
        """ Handles the user clicking the 'Save as...' button.

        Opens a ``wx.FileDialog`` seeded from the current ``fileName`` and,
        if the user confirms, stores the chosen path in ``fileName``.
        """
        import wx
        import os
        dirname = os.path.dirname(self.fileName) or os.getcwd()
        filename = os.path.basename(self.fileName)
        # wx.SAVE / wx.OVERWRITE_PROMPT do not exist in wxPython 4 (Phoenix);
        # the FD_* names are the supported FileDialog styles.
        dlg = wx.FileDialog(None, "Save as...", dirname, filename, "*.csv",
                            wx.FD_SAVE | wx.FD_OVERWRITE_PROMPT)
        try:
            result = dlg.ShowModal()
            inFile = dlg.GetPath()
        finally:
            dlg.Destroy()

        if result == wx.ID_OK:  # Save button was pressed
            self.fileName = inFile

    @property
    def default_view(self):
        """traitsui view: input selector plus a 'Save as...' button and read-only path."""
        from traitsui.api import View, Group, Item, HGroup
        from PYME.ui.custom_traits_editors import CBEditor

        return View(
            Group(Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
                  HGroup(
                      Item('saveAs', show_label=False),
                      '_',
                      Item('fileName', style='readonly', springy=True)
                  )
            ), buttons=['OK'])
Example #13
0
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Take a pair of images and calculate the Fourier shell/ring correlation (FSC / FRC).

    Inputs
    ------
    input_image_a : ImageStack
        First of the two images.

    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fourier power spectra of both images and their cross power spectrum (fftshifted).
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.

    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images. If blank, ``input_image_a`` is used as both images.
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In that case, either selects the z plane to use (>=0) or
        performs a maximum projection (<0) for the first image.
    image_b_z : int
        Same as image_a_z, for the second image.
    flatten_z : Bool
        If enabled, ignores z information and performs a 2D FRC.
    pre_filter : string
        Method used to filter the images prior to the Fourier transform.
    frc_smoothing_func : string
        Method used to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for the cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output.
    """

    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)

    def execute(self, namespace):
        """
        Compute the FRC/FSC between ``input_image_a`` and the image at ``image_b_path``.

        Writes ``output_frc_dict`` and ``output_frc_raw`` into ``namespace``. The base class writes
        ``output_fft_images_cc`` and ``output_frc_plot``. Results are then saved to ``save_path``
        if that is set.
        """
        import multiprocessing

        self._namespace = namespace

        if self.multiprocessing:
            # Only two FFTs are run, so more than two workers is pointless. Leave one core free,
            # but never request fewer than one process (Pool(0) raises ValueError).
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            image_a = namespace[self.input_image_a]

            if len(self.image_b_path.strip()) == 0:
                image_b = image_a
            else:
                image_b = ImageStack(filename=self.image_b_path)

            # the pixel size of image A is used for both images
            self._pixel_size_in_nm = self._voxelsize_in_nm(image_a.mdh)

            image_a_data = image_a.data[:, :, :, self.c_channel].squeeze()
            image_b_data = image_b.data[:, :, :, self.c_channel].squeeze()
            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
                image_a_data = self._project_z(image_a_data, self.image_a_z)
                image_b_data = self._project_z(image_b_data, self.image_b_z)

            image_pair = self.preprocess_images([image_a_data, image_b_data])

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # release worker processes even if the computation failed
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)

    @staticmethod
    def _voxelsize_in_nm(mdh):
        """Return the (x, y, z) voxel size from ``mdh`` as a float array in nm. z stays 0 if unreadable."""
        pixel_size = np.zeros(3, dtype=float)
        pixel_size[0] = mdh.voxelsize.x
        pixel_size[1] = mdh.voxelsize.y
        try:
            pixel_size[2] = mdh.voxelsize.z
        except Exception:
            # best effort: z size may be absent/unusable for this metadata; keep 0
            pass
        if mdh.voxelsize.units == 'um':
            pixel_size *= 1.E3
        return pixel_size

    @staticmethod
    def _project_z(data, z_index):
        """Select plane ``z_index`` (>= 0) or max-project (< 0) along axis 2 of a 3D array."""
        if data.ndim < 3:
            # nothing to flatten (e.g. a singleton z axis was already removed by squeeze())
            return data
        if z_index >= 0:
            return data[:, :, z_index]
        return data.max(2)
Example #14
0
class CalculateFRCBase(ModuleBase):
    """
    Base class for Fourier ring/shell correlation (FRC/FSC) recipe modules.

    Refer to derived classes for the full parameter docstrings. A derived ``execute`` is expected to
    set the following before calling ``calculate_FRC_from_images``:

    * ``self._namespace``: the recipe namespace that outputs are written into;
    * ``self._pixel_size_in_nm``: the per-axis pixel size;
    * ``self._pool``: a ``multiprocessing.Pool``, when ``multiprocessing`` is enabled.
    """
    # window applied to the images before the FFT
    pre_filter = Enum(['Tukey_1/8', None])
    # curve fitted to the raw FRC before the threshold crossings are located
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    # run the FFTs through self._pool when True
    multiprocessing = Bool(True)
    # smoothing factor ``s`` for the cubic UnivariateSpline
    cubic_smoothing = Float(0.01)

    # optional output path; when non-empty, results are written with np.savez_compressed
    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, namespace=None):
        """Abstract entry point; derived classes must override this."""
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Check that the images share one shape, then apply the configured ``pre_filter``.

        Parameters
        ----------
        image_pair : list of ndarray
            Images to compare (normally two).

        Returns
        -------
        list of ndarray
            The (optionally windowed) images.

        Raises
        ------
        ValueError
            If the images differ in shape or ``pre_filter`` is not recognised.
        """
        if len({im.shape for im in image_pair}) > 1:
            raise ValueError("Images not the same dimension.")

        # apply filtering to zero near edge of images
        if self.pre_filter == 'Tukey_1/8':
            image_pair = self.filter_images_tukey(image_pair, 1. / 8)
        elif self.pre_filter is not None:
            raise ValueError("Unknown pre_filter: {!r}".format(self.pre_filter))

        return image_pair

    def filter_images_tukey(self, images, alpha):
        """
        Multiply every image by a separable N-D Tukey window.

        The window is built from the shape of ``images[0]``; all images must share that shape.

        Parameters
        ----------
        images : list of ndarray
            Equally shaped images.
        alpha : float
            Tukey taper fraction, passed to ``scipy.signal.windows.tukey``.

        Returns
        -------
        list of ndarray
            One windowed copy per input image.
        """
        try:
            from scipy.signal.windows import tukey
        except ImportError:
            # very old SciPy only exposes tukey at scipy.signal (removed there in SciPy 1.13)
            from scipy.signal import tukey

        windows = [tukey(n, alpha=alpha) for n in images[0].shape]
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)

        return [im * window_nd for im in images]

    def calculate_FRC_from_images(self, image_pair, mdh):
        """
        Compute the FRC/FSC between two equally shaped images.

        Side effects:

        * stores an fftshifted stack of the two power spectra and the cross power spectrum in
          ``self._namespace[self.output_fft_images_cc]``. This is best effort: failures are only
          printed;
        * stores the FRC plot via ``calculate_threshold``.

        Parameters
        ----------
        image_pair : list of two ndarray
            The images to correlate.
        mdh : object
            Unused; kept for interface compatibility.

        Returns
        -------
        res : dict
            Resolution estimates, see ``calculate_threshold``.
        rawdata : dict
            Curves used for plotting/saving, see ``calculate_threshold``.
        """
        if self.multiprocessing:
            pending = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [job.get() for job in pending]
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        # radial spatial frequency of every Fourier pixel (pixel size sets the frequency unit)
        shape = image_pair[0].shape
        im_fft_freqs = [np.fft.fftfreq(shape[i], self._pixel_size_in_nm[i]) for i in range(len(shape))]
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        im1_fft_power = np.multiply(ft_images[0], np.conj(ft_images[0]))
        im2_fft_power = np.multiply(ft_images[1], np.conj(ft_images[1]))
        im12_fft_power = np.multiply(ft_images[0], np.conj(ft_images[1]))

        try:
            spectra = [np.atleast_3d(np.fft.fftshift(p)) for p in (im1_fft_power, im2_fft_power, im12_fft_power)]
            self._namespace[self.output_fft_images_cc] = ImageStack(data=np.stack(spectra, 3), titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            # diagnostic output only; the FRC itself does not depend on it
            print(e)

        radii = im_R.flatten()
        im1_fft_flat_res = CalculateFRCBase.BinData(radii, im1_fft_power.flatten(), statistic='mean', bins=201)
        im2_fft_flat_res = CalculateFRCBase.BinData(radii, im2_fft_power.flatten(), statistic='mean', bins=201)
        im12_fft_flat_res = CalculateFRCBase.BinData(radii, im12_fft_power.flatten(), statistic='mean', bins=201)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic * im2_fft_flat_res.statistic))

        freq = im12_fft_flat_res.bin_edges[:-1]  # left bin edges
        smoothed_frc = self.smooth_frc(freq, corr, self.cubic_smoothing)

        res, rawdata = self.calculate_threshold(freq, corr, smoothed_frc, im12_fft_flat_res.counts)

        return res, rawdata

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """
        Return a callable smoothed FRC curve, chosen by ``frc_smoothing_func``.

        The callable accepts ``x=`` as a keyword.

        * ``None``: step interpolation (``interp1d(kind='next')``) of the raw curve.
        * ``'Sigmoid'``: ``Sigmoid`` fitted by Nelder-Mead least squares.
        * ``'Cubic Spline'``: ``UnivariateSpline(k=3, s=cubic_smoothing)``.

        Raises
        ------
        ValueError
            If ``frc_smoothing_func`` is not one of the above.
        """
        if self.frc_smoothing_func is None:
            return interpolate.interp1d(freq, corr, kind='next')

        if self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            # start from slope 1 and midpoint at the middle sampled frequency
            # (integer division: a float index is a TypeError on Python 3)
            initial = [1, freq[len(freq) // 2]]
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x) - corr)),
                                        initial, args=(freq,), method='Nelder-Mead')
            return partial(func, n=fit_res.x[0], c=fit_res.x[1])

        if self.frc_smoothing_func == "Cubic Spline":
            return interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)

        raise ValueError("Unknown frc_smoothing_func: {!r}".format(self.frc_smoothing_func))

    @staticmethod
    def _find_threshold_crossing(freq, corr_func, threshold_func):
        """
        Nelder-Mead minimise ``(corr_func(x) - threshold_func(x))**2``.

        The search starts at the first sampled frequency where the smoothed curve lies below the
        threshold, or at ``freq[0]`` if it never does.
        """
        below = corr_func(x=freq) - threshold_func(freq) < 0
        start = freq[np.argmax(below)]
        return optimize.minimize(lambda x: np.square(corr_func(x=x) - threshold_func(x)), start, method='Nelder-Mead')

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Locate where the smoothed FRC crosses the 0.143 (~1/7), 3-sigma and 1/2-bit thresholds.

        Also draws the FRC plot and stores ``Plot(plot)`` in
        ``self._namespace[self.output_frc_plot]``.

        Parameters
        ----------
        freq : ndarray
            Left bin edges (spatial frequency, in reciprocal pixel-size units).
        corr : ndarray
            Raw FRC value per bin.
        corr_func : callable
            Smoothed FRC, called as ``corr_func(x=...)``.
        counts : ndarray
            Number of Fourier pixels per bin.

        Returns
        -------
        res : dict
            Resolution (1 / crossing frequency) under the keys 'frc 1/7', 'frc 3 sigma' and
            'frc half bit'.
        rawdata : dict
            The curves 'freq', 'corr', 'smooth', '1/7', '3 sigma' and '1/2 bit'.
        """
        res = dict()

        fsc_0143 = self._find_threshold_crossing(freq, corr_func, lambda x: 0.143)
        res['frc 1/7'] = 1. / fsc_0143.x[0]

        sigma = 1.0 / np.sqrt(counts * 0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = self._find_threshold_crossing(freq, corr_func, lambda x: 3. * sigma_spl(x))
        res['frc 3 sigma'] = 1. / fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        half_bit = (0.2071 + 1.9102 / np.sqrt(counts)) / (1.2071 + 0.9102 / np.sqrt(counts))
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = self._find_threshold_crossing(freq, corr_func, half_bit_spl)
        res['frc half bit'] = 1. / fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1, 2, figsize=(10, 4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7:  {:.2f} nm".format(1. / fsc_0143.x[0])

            axes[0].plot(freq, 3 * sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma:  {:.2f} nm".format(1. / fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')
            frc_text += "\n1/2 bit:  {:.2f} nm".format(1. / fsc_half_bit.x[0])

            axes[0].legend()

            # label the frequency axis in resolution (1 / frequency) instead
            x_ticklocs = axes[0].get_xticks()
            axes[0].set_xticklabels(["{:.1f}".format(1. / i) for i in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center', verticalalignment='center', transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()

        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq': freq, 'corr': corr, 'smooth': corr_func(x=freq), '1/7': np.ones_like(freq) / 7,
                   '3 sigma': 3 * sigma_spl(freq), '1/2 bit': half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """
        Write raw curves and results to ``save_path`` with ``np.savez_compressed``.

        Nothing is written if ``save_path`` is empty or None.
        """
        # truthiness, not identity: `x is not ""` is unreliable and lets None through
        if self.save_path:
            np.savez_compressed(self.save_path, raw=namespace[self.output_frc_raw], results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Bin ``data`` by ``indexes`` and reduce each bin. Complex ``data`` is supported.

        ``bins`` edges are spaced evenly over [min(indexes), max(indexes)], giving ``bins - 1``
        bins. Each bin is half-open [e_i, e_{i+1}), except the last, which also includes the
        maximum.

        Parameters
        ----------
        indexes : ndarray
            Value that decides each sample's bin (flattened).
        data : ndarray
            Samples to reduce (flattened). It must have as many elements as ``indexes``.
        statistic : {'mean', 'sum'}
            Reduction applied per bin. Empty bins give NaN for 'mean'.
        bins : int
            Number of bin edges.

        Returns
        -------
        object
            An object with attributes ``statistic`` (per-bin value, dtype of ``data``),
            ``bin_edges`` and ``counts`` (samples per bin).

        Raises
        ------
        ValueError
            If ``statistic`` is not 'mean' or 'sum'.
        """
        reducers = {'mean': np.mean, 'sum': np.sum}
        if statistic not in reducers:
            raise ValueError("Unsupported statistic: {!r}".format(statistic))
        func = reducers[statistic]

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        flat_indexes = np.asarray(indexes).flatten()
        flat_data = np.asarray(data).flatten()

        edges = np.linspace(flat_indexes.min(), flat_indexes.max(), bins)
        n_bins = len(edges) - 1
        binned = np.zeros(n_bins, dtype=flat_data.dtype)
        counts = np.zeros(n_bins, dtype=int)

        order = np.argsort(flat_indexes)
        indexes_sorted = flat_indexes[order]
        data_sorted = flat_data[order]
        edges_indexes = np.searchsorted(indexes_sorted, edges)
        # close the last bin on the right; otherwise samples equal to the maximum are dropped
        edges_indexes[-1] = len(indexes_sorted)

        for i in range(n_bins):
            values = data_sorted[edges_indexes[i]:edges_indexes[i + 1]]
            binned[i] = func(values)
            counts[i] = len(values)

        res = Result()
        res.statistic = binned
        res.bin_edges = edges
        res.counts = counts
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """Logistic curve ``1 - 1 / (1 + exp(n * (c - x)))``; equals 0.5 at x = c, decreasing for n > 0."""
        res = 1 - 1 / (1 + np.exp(n * (-x + c)))
        return res