Пример #1
0
class ExtractChannel(ModuleBase):
    """Extract one channel from an image.

    Inputs
    ------
    inputName : ImageStack
        Multi-channel image to take the channel from.

    Outputs
    -------
    outputName : ImageStack
        Image holding ``input.data[:, :, :, channelToExtract]``, with the input's
        metadata copied across and ``Parent`` set to the input's filename.

    Parameters
    ----------
    channelToExtract : int
        Index (along the 4th data axis) of the channel to extract.
    """
    inputName = Input('input')
    outputName = Output('filtered_image')

    channelToExtract = Int(0)

    def _pickChannel(self, image):
        """Return a new single-channel ImageStack built from ``image``."""
        chan = image.data[:, :, :, self.channelToExtract]

        im = ImageStack(chan, titleStub='Filtered Image')
        im.mdh.copyEntriesFrom(image.mdh)
        try:
            # only one channel survives, so keep only its name
            im.mdh['ChannelNames'] = [image.names[self.channelToExtract]]
        except (KeyError, IndexError, AttributeError):
            # names may be missing, or shorter than the number of channels
            logger.warning("Error setting channel name")

        im.mdh['Parent'] = image.filename

        return im

    def execute(self, namespace):
        namespace[self.outputName] = self._pickChannel(
            namespace[self.inputName])
Пример #2
0
class HistByID(ModuleBase):
    """Plot a histogram of a measurement column, taking one value per object ID.

    Values of ``histkey`` are first reduced to one entry per unique ``IDkey``
    (via ``uniqueByID``) and then histogrammed into a new matplotlib figure.

    Parameters
    ----------
    IDkey : str
        Column holding the object IDs.
    histkey : str
        Column to histogram.
    nbins : int
        Number of histogram bins.
    minval, maxval : float
        Histogram range. NaN (the default) means use the data minimum / maximum.
    """
    inputName = Input('measurements')
    IDkey = CStr('objectID')
    histkey = CStr('qIndex')
    outputName = Output('outGraph')
    nbins = Int(50)
    minval = Float(float('nan'))
    maxval = Float(float('nan'))

    def execute(self, namespace):
        import math
        meas = namespace[self.inputName]
        ids = meas[self.IDkey]

        _, valsu = uniqueByID(ids, meas[self.histkey])

        # NaN bounds mean "auto": fall back to the range of the data
        minv = valsu.min() if math.isnan(self.minval) else self.minval
        maxv = valsu.max() if math.isnan(self.maxval) else self.maxval

        import matplotlib.pyplot as plt
        plt.figure()
        plt.hist(valsu, self.nbins, range=(minv, maxv))
        plt.xlabel(self.histkey)
        # NOTE(review): nothing is stored under self.outputName; the histogram
        # is only drawn into the current matplotlib figure.

    @property
    def _key_choices(self):
        #try and find the available column names
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            # best effort: no parent / input not yet available -> no choices
            return []

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName',
                         editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('histkey',
                         editor=CBEditor(choices=self._key_choices)),
                    Item('nbins'),
                    Item('minval'),
                    Item('maxval'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
class SlidingWindowMAD(ModuleBase):
    """
    Using a rolling window along the time (/ z) dimension, calculate the median-absolute deviation (MAD)

    Parameters
    ----------
    series: PYME.IO.image.ImageStack
    time_window_size: int
        Size of window to use in rolling-median and standard deviation calculations

    Returns
    -------
    output: PYME.IO.image.ImageStack
        Unscaled MAD (``scale=1``) calculated within the rolling window. The window size is kept constant, so an
        input of N frames gives N - time_window_size + 1 output frames (one per complete window). The output is
        floating point (at least float32), since the MAD of integer data can be fractional.

    Notes
    -----
    Currently only set up for single-color data
    """

    input = Input('input')

    time_window_size = Int(10)

    process_frames_individually = False

    output = Output('MAD')

    def execute(self, namespace):
        try:
            from scipy.stats import median_abs_deviation
        except ImportError:
            # SciPy < 1.5 only has the older name (removed in SciPy 1.9)
            from scipy.stats import median_absolute_deviation as median_abs_deviation

        series = namespace[self.input]
        window = self.time_window_size

        # one output frame per complete window covering frames [ti, ti + window)
        n_steps = max(series.data.shape[2] - window + 1, 0)

        # MAD of integer data can be fractional; avoid truncating it
        out_dtype = np.result_type(series.data[:, :, 0, 0].dtype, np.float32)
        output = np.empty(
            (series.data.shape[0], series.data.shape[1], n_steps),
            dtype=out_dtype)  # only 1 color for now

        for ti in range(n_steps):
            output[:, :, ti] = median_abs_deviation(
                series.data[:, :, ti:ti + window],
                scale=1,
                axis=2)

        out = image.ImageStack(data=output)
        out.mdh = MetaDataHandler.NestedClassMDHandler()
        try:
            out.mdh.copyEntriesFrom(series.mdh)
        except AttributeError:
            # input has no metadata to propagate
            pass
        out.mdh['Analysis.FilterSpikes.TimeWindowSize'] = self.time_window_size

        namespace[self.output] = out
Пример #4
0
class InterpolateDrift(ModuleBase):
    """
    Build a spline interpolator from raw drift data (``scipy.interpolate.UnivariateSpline``).

    Inputs
    ------
    input_drift_raw : tuple of arrays
        ``(tIndex, drift)`` as measured from a localization or image dataset.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator; call it with frame number / time to get the drift.
    output_drift_plot : Plot
        Plot comparing the raw and the interpolated drift.

    Parameters
    ----------
    degree_of_spline : int
        Degree of the smoothing spline (1 for linear, 3 for cubic).
    smoothing_factor : float
        Smoothing factor. 0 means no smoothing; a negative value selects the
        UnivariateSpline default.
    """

    degree_of_spline = Int(3)
    smoothing_factor = Float(-1)
    input_drift_raw = Input('drift_raw')
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        t_index, raw_drift = namespace[self.input_drift_raw]

        interpolator = interpolate_drift(t_index, raw_drift,
                                         self.degree_of_spline,
                                         self.smoothing_factor)
        namespace[self.output_drift_interpolator] = interpolator

        # The plot is not essential; it only visualises the fit.
        plot_fn = partial(generate_drift_plot, t_index, raw_drift, interpolator)
        namespace[self.output_drift_plot] = Plot(plot_fn)
        namespace[self.output_drift_plot].plot()
Пример #5
0
class CNNFilter(Filter):
    """ Use a previously trained Keras neural network to filter the data. Used for learnt de-noising and/or
    deconvolution. Runs prediction piecewise over the image with over-lapping ROIs
    and averages the prediction results.

    Parameters
    ----------
    model : str
        File name / URI of the saved Keras model.
    step_size : int
        Stride, in pixels, between neighbouring prediction windows.

    Notes
    -----

    Keras and either Tensorflow or Theano must be installed and set up for this module to work. This is not a default
    dependency of python-microscopy as the conda-installable versions don't have GPU support.

    Windows are only placed where they fit entirely inside the image, so pixels near the lower / right edges may be
    covered by fewer windows (or none) and are attenuated accordingly.
    """
    model = FileOrURI('')
    step_size = Int(14)

    def _load_model(self):
        """Load the Keras model named by ``self.model``, caching it between calls.

        The model is only reloaded when ``self.model`` changes. The cached name is
        updated after a successful load, so a failed load is retried next time
        rather than silently leaving the previous model in use.
        """
        from keras.models import load_model
        if getattr(self, '_model_name', None) != self.model:
            with unifiedIO.local_or_temp_filename(self.model) as fn:
                self._model = load_model(fn)
            self._model_name = self.model

    def applyFilter(self, data, chanNum, frNum, im):
        self._load_model()

        out = np.zeros(data.shape, 'f')

        # model input is (batch, x, y, channel), matching the reshape below
        _, kernel_size_x, kernel_size_y, _ = self._model.input_shape

        # With stride step_size, an interior pixel lies in
        # (kernel_size_x / step_size) * (kernel_size_y / step_size) windows;
        # weighting each prediction by the reciprocal averages the overlaps.
        scale_factor = float(self.step_size)**2 / (kernel_size_x *
                                                   kernel_size_y)

        # "+ 1" so a window ending exactly on the image edge is included
        for i_x in range(0, out.shape[0] - kernel_size_x + 1, self.step_size):
            for i_y in range(0, out.shape[1] - kernel_size_y + 1,
                             self.step_size):
                d_i = data[i_x:(i_x + kernel_size_x),
                           i_y:(i_y + kernel_size_y)].reshape(
                               1, kernel_size_x, kernel_size_y, 1)
                p = self._model.predict(d_i).reshape(kernel_size_x,
                                                     kernel_size_y)
                out[i_x:(i_x + kernel_size_x),
                    i_y:(i_y + kernel_size_y)] += scale_factor * p

        return out

    def completeMetadata(self, im):
        im.mdh['CNNFilter.model'] = self.model
Пример #6
0
class PointFeatureBase(ModuleBase):
    """
    common base class for feature extraction routines - implements normalisation and PCA routines

    Subclasses compute an (n_points, n_features) feature array and pass it to
    ``_process_features``, which optionally normalises it, optionally reduces it
    with PCA, and attaches it as a column of a ``MappingFilter`` over the input.
    """

    outputColumnName = CStr('features')
    columnForEachFeature = Bool(
        False
    )  #if true, outputs a column for each feature - useful for visualising

    normalise = Bool(True)  #subtract mean and divide by std. deviation

    PCA = Bool(
        True
    )  # reduce feature dimensionality by performing PCA - TODO - should this be a separate module and be chained instead?
    PCA_components = Int(3)  # 0 = same dimensionality as features

    def _process_features(self, data, features):
        """Normalise / PCA-reduce ``features`` and attach them to ``data``.

        Parameters
        ----------
        data : tabular data source
            Point data the features belong to; wrapped in a MappingFilter.
        features : ndarray, shape (n_points, n_features)

        Returns
        -------
        tabular.MappingFilter
            ``data`` plus a ``outputColumnName`` column (and ``feat_<i>``
            columns if ``columnForEachFeature``). When PCA is used, the fitted
            PCA object is attached as ``.pca``.
        """
        from PYME.IO import tabular
        out = tabular.MappingFilter(data)
        out.mdh = getattr(data, 'mdh', None)

        if self.normalise:
            features = features - features.mean(0)[None, :]
            std = features.std(0)
            # A constant feature has zero spread; leave it at 0 rather than
            # producing NaNs, which would make the PCA fit below fail.
            std[std == 0] = 1
            features = features / std[None, :]

        if self.PCA:
            from sklearn.decomposition import PCA

            pca = PCA(n_components=(
                self.PCA_components if self.PCA_components > 0 else None
            )).fit(features)
            features = pca.transform(features)

            out.pca = pca  #save the pca object just in case we want to look at what the principle components are (this is hacky)

        out.addColumn(self.outputColumnName, features)

        if self.columnForEachFeature:
            for i in range(features.shape[1]):
                out.addColumn('feat_%d' % i, features[:, i])

        return out
Пример #7
0
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its distances to all other points

    Parameters
    ----------
    binWidth : float
        Width of each distance bin, in nm.
    numBins : int
        Number of bins (starting at distance 0).
    threeD : bool
        Use 3D (x, y, z) distances; otherwise 2D (x, y).
    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  #number of bins (starting at 0)
    threeD = Bool(True)

    normaliseRelativeDensity = Bool(
        False
    )  # divide by the sum of all radial bins. If not performed, the first principle component will likely be average density

    def execute(self, namespace):
        from PYME.Analysis.points import DistHist
        points = namespace[self.inputLocalisations]

        # range (not the Python-2-only xrange) so this runs on Python 3
        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([
                DistHist.distanceHistogram3D(x[i], y[i], z[i], x, y, z,
                                             self.numBins, self.binWidth)
                for i in range(len(x))
            ])
        else:
            x, y = points['x'], points['y']
            f = np.array([
                DistHist.distanceHistogram(x[i], y[i], x, y, self.numBins,
                                           self.binWidth)
                for i in range(len(x))
            ])

        namespace[self.outputName] = self._process_features(points, f)
Пример #8
0
class DBSCANClustering2(ModuleBase):
    """
    Performs DBSCAN clustering on input dictionary

    Parameters
    ----------

        columns: names of the columns used as point coordinates
        searchRadius: search radius for clustering (DBSCAN ``eps``)
        minClumpSize: number of points within searchRadius required for a given point to be considered a core point
            (DBSCAN ``min_samples``)
        numberOfJobs: number of parallel jobs passed to DBSCAN (default: one less than the number of CPUs)

    Notes
    -----

    See `sklearn.cluster.dbscan` for more details about the underlying algorithm and parameter meanings.

    The output adds a ``dbscanClumpID`` column in which 0 marks unclustered points and clusters are numbered from 1.
    """
    import multiprocessing
    inputName = Input('filtered')

    columns = ListStr(['x', 'y', 'z'])
    searchRadius = Float()
    minClumpSize = Int()
    numberOfJobs = Int(
        max(multiprocessing.cpu_count() - 1,
            1))  # n_jobs is only supported by newer versions of sklearn's dbscan

    outputName = Output('dbscanClustered')

    def execute(self, namespace):
        from sklearn.cluster import dbscan

        inp = namespace[self.inputName]
        mapped = tabular.MappingFilter(inp)

        points = np.vstack([inp[k] for k in self.columns]).T

        # eps / min_samples are passed by keyword: they are keyword-only in
        # recent sklearn versions.
        # Note that sklearn gives unclustered points label of -1, and first value starts at 0.
        try:
            core_samp, dbLabels = dbscan(points,
                                         eps=self.searchRadius,
                                         min_samples=self.minClumpSize,
                                         n_jobs=self.numberOfJobs)
            multiproc = True
        except Exception:
            # best effort: older sklearn lacks n_jobs, and parallel execution
            # can fail in some environments - retry single-threaded
            logger.debug('dbscan with n_jobs failed', exc_info=True)
            core_samp, dbLabels = dbscan(points,
                                         eps=self.searchRadius,
                                         min_samples=self.minClumpSize)
            multiproc = False

        if multiproc:
            logger.info('using dbscan multiproc version')
        else:
            logger.info('falling back to dbscan single-threaded version')

        # shift dbscan labels up by one to match existing convention that a clumpID of 0 corresponds to unclumped
        mapped.addColumn('dbscanClumpID', dbLabels + 1)

        # propogate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        return ['columns']
Пример #9
0
class LoadDriftandInterp(ModuleBase):
    """
    Loads drift data from file(s) and uses them to create a spline interpolator (``scipy.interpolate.UnivariateSpline``).

    Inputs
    ------
    input_dummy : None
       Blank input. Required to run correctly.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator. Returns drift when called with frame number / time.
        The interpolators built for each file are summed.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    load_paths : list of File
        List of files to load. Each must contain ``tIndex`` and ``drift``
        entries (e.g. an ``.npz`` archive).
    degree_of_spline : int
        Degree of the smoothing spline (1 for linear, 3 for cubic).
    smoothing_factor : float
        Smoothing factor. 0 for no smoothing; negative for the UnivariateSpline default.
    """

    input_dummy = Input('input')  # breaks GUI without this???
    load_paths = List(File, [""], 1)
    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(
        -1)  # 0 for no smoothing. set to negative for UnivariateSpline default
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        per_file_splines = list()
        tIndexes = list()
        drifts = list()
        for fil in self.load_paths:
            data = np.load(fil)
            try:
                tIndex = data['tIndex']
                drift = data['drift']
            finally:
                # an .npz archive keeps its file handle open until closed
                if hasattr(data, 'close'):
                    data.close()

            tIndexes.append(tIndex)
            drifts.append(drift)

            per_file_splines.append(
                interpolate_drift(tIndex, drift, self.degree_of_spline,
                                  self.smoothing_factor))

        def spl_method(funcs, t):
            return np.sum([f(t) for f in funcs], axis=0)

        # Transpose file-major -> component-major: each entry of the zip groups
        # the i-th interpolator from every file (presumably one per drift
        # dimension - TODO confirm), and the combined interpolator sums them.
        spl_combined = [
            partial(spl_method, funcs) for funcs in zip(*per_file_splines)
        ]

        namespace[self.output_drift_interpolator] = spl_combined

        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(
            partial(generate_drift_plot, tIndexes, drifts, spl_combined))
        namespace[self.output_drift_plot].plot()
class ClusteringByLabel(ModuleBase):
    """
    Compute per-label intensity statistics of an image series.

    For each positive label in ``mask``, the module averages over that label's pixels:
    (a) the temporal variance divided by the temporal mean, taken from ``excitation_start_frame`` onwards, and
    (b) the temporal mean of the frames before ``excitation_start_frame``.

    Parameters
    ----------
    input_name : Input
        PYME.IO.ImageStack. Presumably a single-colour (x, y, t) series once squeezed - TODO confirm.
    mask : Input
        PYME.IO.ImageStack. Optional label image restricting which pixels are analysed; labels <= 0 are ignored.
        If empty, the whole image is analysed as a single label (1).
    excitation_start_frame : int
        First frame of the post-excitation period.
    output_vom : str
        If non-empty, the per-pixel variance-over-mean image is also stored under this name.
    output_mean_pre_excitation : str
        If non-empty, the per-pixel pre-excitation mean image is also stored under this name.

    Returns
    -------
    output_name : Output
        Table with columns ``labels``, ``variance_over_mean`` and ``mean_intensity_over_first_10_frames``.
        Despite its name, the latter is the mean over the first ``excitation_start_frame`` frames.
    """

    input_name = Input('input')

    mask = Input('')
    excitation_start_frame = Int(10)

    output_vom = CStr('')
    output_mean_pre_excitation = CStr('')
    output_name = Output('cluster_metrics')

    def execute(self, namespace):

        series = namespace[self.input_name]

        # squeeze down from 4D
        data = series.data[:, :, :].squeeze()
        if self.mask == '':  # not the most memory efficient, but make a mask
            logger.debug(
                'No mask provided to ClusteringByLabel, analyzing full image')
            mask = np.ones((data.shape[0], data.shape[1]), int)
        else:
            mask = namespace[self.mask].data[:, :, :].squeeze()
        # toss any negative labels, as well as the zero label (per PYME clustering schema).
        labels = sorted(list(set(np.clip(np.unique(mask), 0, None)) - {0}))
        logger.debug('ClusteringByLabel labels: %s', labels)

        n_labels = len(labels)

        # calculate the Variance_t over Mean_t
        post_excitation = data[:, :, self.excitation_start_frame:]
        var = np.var(post_excitation, axis=2)
        mean = np.mean(post_excitation, axis=2)

        variance_over_mean = var / mean
        if np.isnan(variance_over_mean).any():
            logger.error('Variance over mean contains NaN, see %s',
                         getattr(series, 'filename', None))

        mean_pre_excitation = np.mean(data[:, :, :self.excitation_start_frame],
                                      axis=2)

        cluster_metric_mean = np.zeros(n_labels)
        mean_before_excitation = np.zeros(n_labels)

        for li, label in enumerate(labels):
            # everything is 2D at this point
            label_mask = mask == label
            cluster_metric_mean[li] = np.mean(variance_over_mean[label_mask])
            mean_before_excitation[li] = np.mean(
                mean_pre_excitation[label_mask])

        res = tabular.DictSource({
            'variance_over_mean': cluster_metric_mean,
            'mean_intensity_over_first_10_frames': mean_before_excitation,
            'labels': np.array(labels)
        })
        try:
            res.mdh = series.mdh
        except AttributeError:
            res.mdh = None

        namespace[self.output_name] = res

        if self.output_vom != '':
            namespace[self.output_vom] = image.ImageStack(
                data=variance_over_mean, mdh=res.mdh)

        if self.output_mean_pre_excitation != '':
            namespace[self.output_mean_pre_excitation] = image.ImageStack(
                data=mean_pre_excitation, mdh=res.mdh)
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.

    1. Applies median filter for denoising.

    2. Replaces out of range values to defined values, then subtracts the lower set value from every pixel.

    3. Applies 2D Tukey filter to dampen potential edge artifacts.

    Inputs
    ------
    input_name : ImageStack

    Outputs
    -------
    output_name : ImageStack
        Same dtype as the input (filtering is done in floating point and cast back on storage).

    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``). 0 disables it.
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels at or below the lower threshold are replaced by this value.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels at or above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.windows.tukey``). 0 disables it.
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype
        # plain ints (np.long was removed in NumPy 1.24)
        n_x, n_y, n_t = (int(s) for s in ims.data.shape[:3])

        # Somewhat arbitrary way to decide on chunk size (~1e8 bytes per chunk)
        chunk_size = max(1, int(100000000 / n_x / n_y / dtype.itemsize))

        try:
            tukey = signal.windows.tukey
        except AttributeError:  # SciPy < 1.1; signal.tukey removed in SciPy 1.13
            tukey = signal.tukey
        tukey_mask_x = tukey(n_x, self.tukey_size)
        tukey_mask_y = tukey(n_y, self.tukey_size)
        self._tukey_mask_2d = np.multiply(
            *np.meshgrid(tukey_mask_x, tukey_mask_y, indexing='ij'))[:, :,
                                                                     None]

        shape = (n_x, n_y, n_t)
        if self.cache_clip == "":
            raw_data = np.empty(shape, dtype=dtype)
        else:
            raw_data = np.memmap(self.cache_clip,
                                 dtype=dtype,
                                 mode='w+',
                                 shape=shape)

        # report (and flush the cache) roughly every 20% of the frames
        progress = 0.2 * n_t
        for f in np.arange(0, n_t, chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(
                ims.data[:, :, f:f + chunk_size])

            if (f + chunk_size >= progress):
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += 0.2 * n_t
                print("{:.2f} s. Completed clipping {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + chunk_size, n_t), n_t))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
            Performs the actual filtering here.

            Works on a floating-point copy so that the caller's array is never
            modified and the in-place subtraction is valid for integer input.
        """
        if self.median_filter_size > 0:
            data = ndimage.median_filter(data,
                                         self.median_filter_size,
                                         mode='nearest')
        data = np.asarray(data)
        data = data.astype(np.result_type(data.dtype, np.float32))
        data[data >= self.threshold_upper] = self.clip_to_upper
        data[data <= self.threshold_lower] = self.clip_to_lower
        data -= self.clip_to_lower
        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d
        return data

    def completeMetadata(self, im):
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only outputs drift as a tuple of time points and drift amount.
    Currently not registered by itself since it is not very useful.

    Parameters
    ----------
    cache_fft : File
        Disk cache of the Fourier-transformed images.
    method : {'RCC', 'MCC', 'DCC'}
        Redundant, mean or direct cross-correlation.
    shift_max : float
        RCC only: pairs whose residual exceeds this (nm) are dropped, as long
        as the coefficient matrix stays full rank.
    corr_window : int
        Maximum frame separation of correlated pairs (<= 0 for all pairs).
        Ignored for DCC, which only correlates the first frame with each other frame.
    multiprocessing : bool
        Compute the pairwise shifts in a process pool.
    debug_cor_file : File
        If set, a memmap of this name collects cross-correlation images for debugging.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """Cross-correlate pairs of Fourier-transformed frames.

        Parameters
        ----------
        ft_images : array-like
            Stack of ft images indexed by frame along axis 0 (4D; the last axis
            is presumably a real-FFT half axis - TODO confirm).

        Returns
        -------
        shifts : ndarray, shape (n_pairs, 3)
            Measured shift of each correlated pair (rows with NaN removed).
        coefs : ndarray, shape (n_pairs, n_steps - 1)
            Row k has ones over the frame-to-frame steps spanned by pair k,
            so that ``coefs @ step_drifts`` approximates ``shifts``.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            coefs_size = n_steps * self.corr_window - self.corr_window * (
                self.corr_window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if self.debug_cor_file != "":
            cc_file_shape = [
                shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                (ft_images.shape[3] - 1) * 2
            ]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]),
                          key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float64 (the np.float alias was removed in NumPy 1.24)
            cc_file_args = (self.debug_cor_file, np.float64,
                            tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0],
                                dtype=cc_file_args[1],
                                mode="w+",
                                shape=cc_file_args[2])
            # list() because cc_args is indexed below; zip is lazy on Python 3
            cc_args = list(
                zip(range(shifts.shape[0]),
                    (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (
                        j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                # pair (i, j) spans the frame-to-frame steps i .. j-1
                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if self.cache_fft != "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift,
                                                    None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print(
                            "{:.2f} s. Completed calculating {} of {} total shifts."
                            .format(time.time() - self._start_time,
                                    counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(
                range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                autocor_shift_cache,
                len(ft_1_cache) *
                ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                cc_args)
            for i, (j, res) in enumerate(
                    self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print(
                        "{:.2f} s. Completed calculating {} of {} total shifts."
                        .format(time.time() - self._start_time, i + 1,
                                coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(
            time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if self.debug_cor_file != "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(
            coefs, axis=1))), "Coefficient matrix filled less than expected."

        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".
                  format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (
            np.linalg.matrix_rank(coefs) == n_steps -
            1), "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs  # coefs.shape[1] is n_steps - 1

    def rcc(
        self,
        shift_max,
        t_shift,
        shifts,
        coefs,
    ):
        """
            Should probably rename function.
            Takes cross correlation results and calculates shifts.

            Solves ``coefs @ drifts ~= shifts`` by least squares (pseudo-inverse).
            For RCC, rows whose residual exceeds ``shift_max`` are zeroed in
            ``coefs`` (in place, largest first, only while the matrix stays full
            rank) and the system is solved again.

            Returns ``t_shift`` unchanged and the frame-to-frame drifts with a
            zero row prepended for the first time point.
        """

        print("{:.2f} s. About to start solving shifts array.".format(
            time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() -
                                                            self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[
                residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print(
                        "Could not remove all residuals over shift_max threshold."
                    )
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".
                  format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]],
                        'constant',
                        constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # derived versions of RCC need to override this method
        # 'execute' of this RCC base class is not thoroughly tested as its use is probably quite limited.
        # NOTE(review): as written, this passes the cache *filename* where an
        # array of ft images is expected, and calc_corr_drift_from_ft_images
        # returns (shifts, coefs) while rcc also expects t_shift, so this base
        # implementation cannot run as-is; subclasses must override it.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # always release the worker processes, even if the calculation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
Пример #13
0
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.
    """
    # Additional properties (the rest are inherited from TriangleRenderLayer)
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, context, **kwargs)

        # re-render whenever one of the octree-specific display parameters changes
        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree, subdivides the faces of the selected nodes into triangles and stores the resulting
        vertices, normals, colours and bounding box on the layer.

        Which nodes are drawn depends on ``depth``:

        * ``depth > 0``: nodes at exactly that depth whose point density is at least ``density``
        * ``depth == 0``: every node containing at least one point
        * ``depth < 0``: leaf nodes (all child indices zero), excluding the root

        In all cases only nodes whose parent holds at least ``min_points`` points are considered.

        Parameters
        ----------
        ds :
            Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        # Truncate the octree: only keep nodes whose parent holds at least `min_points` points.
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth and drop those below the density threshold
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3))*ds.box_size(self.depth)
            node_density = 1.*nodes['nPoints']/np.prod(box_sizes, axis=1)
            keep = node_density >= self.density
            nodes = nodes[keep]
            box_sizes = box_sizes[keep]
            alpha = nodes['nPoints']/box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2 ** nodes['depth'])**3)
        else:
            # Plot leaf nodes (sum of child indices is zero), excluding the root
            nodes = nodes[(np.sum(nodes['children'], axis=1) == 0) & (nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints']*((2.0**nodes['depth']))**3

        if len(nodes) == 0:
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))
            return

        centres = nodes['centre']
        # Corner offsets from each node centre. NOTE(review): OCT_SHIFT is assumed to be an (8, 3) table of
        # +/-1 corner signs laid out as in the diagram below.
        shifts = (box_sizes[:, None]*OCT_SHIFT[None, :])*0.5
        v = (centres[:, None, :] + shifts)  # (n_nodes, 8, 3) corner positions

        #
        #     z
        #     ^
        #     |
        #    v4 ----------v6
        #    /|           /|
        #   / |          / |
        #  v5----------v7  |
        #  |  |    c    |  |
        #  | v0---------|-v2
        #  | /          | /
        #  v1-----------v3---> y
        #  /
        # x
        #
        # Now note that the counterclockwise triangles (when viewed straight-on) formed along the faces of the cube are:
        #
        # v0 v2 v1
        # v0 v1 v5
        # v0 v5 v4
        # v0 v6 v2
        # v0 v4 v6
        # v1 v2 v3
        # v1 v3 v7
        # v1 v7 v5
        # v2 v6 v7
        # v2 v7 v3
        # v4 v5 v6
        # v5 v7 v6

        # Counterclockwise triangles (when viewed straight-on) formed along the faces of an octree box.
        # np.vstack over the node axis makes t0/t1/t2 node-major: the 12 triangles of node 0, then of node 1, ...
        t0 = np.vstack(v[:, [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 4, 5], :])
        t1 = np.vstack(v[:, [2, 1, 5, 6, 4, 2, 3, 7, 6, 7, 5, 7], :])
        t2 = np.vstack(v[:, [1, 5, 4, 2, 6, 3, 7, 5, 7, 3, 6, 6], :])

        # positions: three consecutive vertices per triangle
        x, y, z = np.hstack([t0, t1, t2]).reshape(-1, 3).T

        # Now we create the normal as the cross product
        tn = np.cross((t2 - t1), (t0 - t1))

        # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
        xn, yn, zn = np.repeat(tn.T, 3, axis=1)

        # Scale per-node opacity to the layer alpha; avoid 0/0 (NaN) when every selected node is empty.
        alpha_max = alpha.max()
        if alpha_max > 0:
            alpha = self.alpha*alpha/alpha_max
        else:
            alpha = np.zeros(alpha.shape)
        # One value per vertex. Vertices are node-major (12 triangles x 3 vertices = 36 per node), so each
        # node's alpha must be repeated 36 times consecutively to line up with its own triangles.
        alpha = np.repeat(alpha, 36)
        print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

        cmap = getattr(cm, self.cmap)

        # Colour is a fixed constant for the octree; only the colour map and alpha vary.
        c = np.ones(len(x))
        clim = [0, 1]
        cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
        cs = cmap(cs_)

        if self.method in ['flat', 'tessel']:
            alpha = cs_ * alpha

        cs[:, 3] = alpha

        if self.method == 'tessel':
            cs = np.power(cs, 0.333)

        n_vertices = len(x.ravel())
        # ravel() before reshape forces a contiguous (n_vertices, 3) copy
        self._vertices = np.vstack((x.ravel(), y.ravel(), z.ravel())).T.ravel().reshape(n_vertices, 3)
        self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(n_vertices, 3)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        self._colors = cs.ravel().reshape(len(c), 4)

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, EnumEditor

        return View([
                     Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
                     Item('method'),
                     Item('depth'), Item('min_points'),
                     Group([Item('cmap', label='LUT'), Item('alpha')])], )
Пример #14
0
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.
    """
    # Properties to show in the GUI. 'visible' comes from BaseLayer; 'cmap', 'clim', 'alpha' and 'method' are used
    # below but not declared here - presumably inherited from TriangleRenderLayer.
    depth = Int(
        3,
        desc=
        'Depth at which to render Octree. Set to -1 for dynamic depth rendering.'
    )
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, **kwargs)

        # re-render whenever one of the octree-specific display parameters changes
        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Feeds the triangle points/normals to update_data.

        Which nodes are drawn depends on ``depth``:

        * ``depth > 0``: nodes at exactly that depth whose point density is at least ``density``
        * ``depth == 0``: every node containing at least one point
        * ``depth < 0``: leaf nodes (all child indices zero), excluding the root

        In all cases only nodes whose parent holds at least ``min_points`` points are considered.

        Parameters
        ----------
        ds :
            Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        # Truncate the octree: only keep nodes whose parent holds at least `min_points` points.
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(
            self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth and drop those below the density threshold
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            keep = node_density >= self.density
            nodes = nodes[keep]
            box_sizes = box_sizes[keep]
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2**nodes['depth'])**3)
        else:
            # Plot leaf nodes: nodes whose child indices sum to zero, excluding the root.
            # The sum must be taken per node (axis=1): a sum over the whole array is a single scalar, and a
            # scalar boolean index selects either nothing or everything (with an extra leading axis).
            nodes = nodes[(np.sum(nodes['children'], axis=1) == 0)
                          & (nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2.0**nodes['depth']))**3

        if len(nodes) == 0:
            print('No nodes for density {0}, depth {1}'.format(
                self.density, self.depth))
            return

        # First we need the vertices of the cube. We find them from the center c provided and the box size (lx, ly, lz)
        # provided by the octree:
        c = nodes['centre']  # center
        v0 = c + box_sizes * -1 / 2  # v0 = c - lx/2 - ly/2 - lz/2
        v1 = c + box_sizes * [-1, -1, 1] / 2  # v1 = c - lx/2 - ly/2 + lz/2
        v2 = c + box_sizes * [-1, 1, -1] / 2  # v2 = c - lx/2 + ly/2 - lz/2
        v3 = c + box_sizes * [-1, 1, 1] / 2  # v3 = c - lx/2 + ly/2 + lz/2
        v4 = c + box_sizes * [1, -1, -1] / 2  # v4 = c + lx/2 - ly/2 - lz/2
        v5 = c + box_sizes * [1, -1, 1] / 2  # v5 = c + lx/2 - ly/2 + lz/2
        v6 = c + box_sizes * [1, 1, -1] / 2  # v6 = c + lx/2 + ly/2 - lz/2
        v7 = c + box_sizes / 2  # v7 = c + lx/2 + ly/2 + lz/2

        #
        #     z
        #     ^
        #     |
        #    v1 ----------v3
        #    /|           /|
        #   / |          / |
        #  v5----------v7  |
        #  |  |    c    |  |
        #  | v0---------|-v2
        #  | /          | /
        #  v4-----------v6---> y
        #  /
        # x
        #
        # Now note that the counterclockwise triangles (when viewed straight-on) formed along the faces of the cube are:
        #
        # v0 v2 v6
        # v0 v4 v5
        # v1 v3 v2
        # v2 v0 v1
        # v3 v1 v5
        # v3 v7 v6
        # v4 v6 v7
        # v5 v1 v0
        # v5 v7 v3
        # v6 v2 v3
        # v6 v4 v0
        # v7 v5 v4

        # Concatenate vertices, interleave, restore to 3x(3N) points (3xN triangles),
        # and assign the points to x, y, z vectors. Stacking whole (N, 3) corner arrays makes the
        # triangle order face-major: triangle (face f, node n) is row f * N + n.
        triangle_v0 = np.vstack(
            (v0, v0, v1, v2, v3, v3, v4, v5, v5, v6, v6, v7))
        triangle_v1 = np.vstack(
            (v2, v4, v3, v0, v1, v7, v6, v1, v7, v2, v4, v5))
        triangle_v2 = np.vstack(
            (v6, v5, v2, v1, v5, v6, v7, v0, v3, v3, v0, v4))

        x, y, z = np.hstack(
            (triangle_v0, triangle_v1, triangle_v2)).reshape(-1, 3).T

        # Now we create the normal as the cross product (triangle_v2 - triangle_v1) x (triangle_v0 - triangle_v1)
        triangle_normals = np.cross((triangle_v2 - triangle_v1),
                                    (triangle_v0 - triangle_v1))
        # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
        xn, yn, zn = np.repeat(triangle_normals.T, 3, axis=1)

        # Scale per-node opacity to the layer alpha; avoid 0/0 (NaN) when every selected node is empty.
        alpha_max = alpha.max()
        if alpha_max > 0:
            alpha = self.alpha * alpha / alpha_max
        else:
            alpha = np.zeros(alpha.shape)
        # (12, N) ravelled is face-major, matching the triangle order above; then one value per vertex.
        alpha = (alpha[None, :] * np.ones(12)[:, None])
        alpha = np.repeat(alpha.ravel(), 3)
        print('Octree scaled alpha range: %g, %g' %
              (alpha.min(), alpha.max()))

        # Pass the restructured data to update_data
        self.update_data(x,
                         y,
                         z,
                         cmap=getattr(cm, self.cmap),
                         clim=self.clim,
                         alpha=alpha,
                         xn=xn,
                         yn=yn,
                         zn=zn)

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, EnumEditor

        return View(
            [
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
                Item('method'),
                Item('depth'),
                Item('min_points'),
                Group([Item('cmap', label='LUT'),
                       Item('alpha')])
            ], )
Пример #15
0
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    Displays a single (slice, channel) plane of a PYME.IO.image.ImageStack taken from the pipeline entry named by
    ``dsname``.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        self._do = display_opts #a dh5view display_options instance - if provided, this over-rides the the clim, cmap properties

        # cache of the currently extracted image plane: the (dsname, slice, channel) selection and the datasource
        # object it was taken from
        self._im_key = None
        self._im_ds = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        # channel/slice change which plane is displayed, so they need to trigger a refresh as well
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, channel, slice')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            #make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (not self._pipeline is None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (does not go through the pipeline for triangles_mesh).
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            #fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]

    @property
    def _ds_class(self):
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """Refresh the datasource choices and, if an engine and datasource are available, re-render."""
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]
        except AttributeError:
            # fallback if pipeline is a dictionary
            self._datasource_choices = [k for k, v in self._pipeline.items() if
                                        isinstance(v, self._ds_class)]

        if self.engine is None:
            return

        ds = self.datasource  # look the datasource up only once
        if ds is not None:
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """Copy offset/gain, colour map and visibility for our channel from dh5view display options.

        Uses ``do`` if given, otherwise the display options passed to the constructor; does nothing if neither
        is available.
        """
        if (do is None):
            if not (self._do is None):
                do = self._do
            else:
                return

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Extract the selected image plane from ``ds`` (if needed) and store the display parameters.

        Parameters
        ----------
        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        clim = self.clim
        cmap = getattr(cm, self.cmap)

        alpha = float(self.alpha)

        # Re-extract the plane only when the selection changes OR the datasource object has been replaced
        # (e.g. the pipeline was rebuilt and now holds a new image under the same name). Comparing by name
        # alone would keep showing the stale cached plane after such a rebuild.
        im_key = (self.dsname, self.slice, self.channel)

        if (im_key != self._im_key) or (ds is not self._im_ds):
            self._im_key = im_key
            self._im_ds = ds
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])

            self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        # every 20th pixel of the current plane - used to populate the clim histogram editor
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            # NOTE(review): this condition looks copied from the triangle-mesh layer; confirm
                            # whether image engines ever use 'flat'/'tessel', otherwise alpha is never shown
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class FilterSpikes(ModuleBase):
    """
    Using a rolling window along the time (/ z) dimension, identify spikes which are greatly above the median within
    that window and remove them by replacing the value with the median.

    Parameters
    ----------
    series: PYME.IO.image.ImageStack
    time_window_size: int
        Size of window to use in rolling-median and median-absolute-deviation calculations. Every window position
        (including the one ending on the last frame) is examined.
    threshold_factor: float
        Multiplicative factor used to set the spike threshold, which is threshold * median-absolute deviation + median,
        all calculated within the window for an individual x, y, pixel.
    threshold_change: float
        Change (increase) required in a single time-step for a spike candidate to be considered a spike

    Returns
    -------
    output: PYME.IO.image.ImageStack
        Spike-filtered copy of the input series


    Notes
    -----
    Currently only set up for single-color data (channel 0 is used).
    """

    input = Input('input')

    time_window_size = Int(10)
    threshold_factor = Float(5)
    threshold_change = Float(370)

    process_frames_individually = False

    output = Output('filtered')

    def execute(self, namespace):
        series = namespace[self.input]

        # only 1 color for now. Reshape (rather than squeeze) so singleton x/y dimensions are preserved.
        n_x, n_y, n_t = series.data.shape[:3]
        frames = np.asarray(series.data[:, :, :, 0]).reshape(n_x, n_y, n_t)

        # Frame-to-frame change, computed in floating point: np.diff on unsigned-integer (e.g. raw camera) data
        # wraps around on decreasing intensities, which would flag every drop as a large positive jump.
        jumps = np.diff(frames.astype(np.float64), axis=2)
        over_jump_threshold = np.zeros((n_x, n_y, n_t), dtype=bool)
        over_jump_threshold[:, :, 1:] = jumps > self.threshold_change

        output = frames.copy()
        window = self.time_window_size

        # Slide over every window start, including n_t - window, so that the final frame is also examined.
        # `output` is updated in place, so later windows see already-corrected values.
        for ti in range(n_t - window + 1):
            data = output[:, :, ti:ti + window]
            median = np.median(data, axis=2)
            # raw (unscaled) median absolute deviation along time
            mad = np.median(np.abs(data - median[:, :, None]), axis=2)
            spikes = np.logical_and(
                data > (self.threshold_factor * mad + median)[:, :, None],
                over_jump_threshold[:, :, ti:ti + window])
            sx, sy, st = np.nonzero(spikes)
            output[sx, sy, st + ti] = median[sx, sy]

        out = image.ImageStack(data=output)
        out.mdh = MetaDataHandler.NestedClassMDHandler()
        try:
            out.mdh.copyEntriesFrom(series.mdh)
        except AttributeError:
            pass
        out.mdh[
            'Analysis.FilterSpikes.ThresholdFactor'] = self.threshold_factor
        out.mdh[
            'Analysis.FilterSpikes.ThresholdChange'] = self.threshold_change
        out.mdh['Analysis.FilterSpikes.TimeWindowSize'] = self.time_window_size

        namespace[self.output] = out
Пример #17
0
class FiducialTrack(ModuleBase):
    """
    Extract an average fiducial track from the input pipeline.

    Parameters
    ----------

        radiusMultiplier: multiplied with error_x to give the search radius used for clustering
        timeWindow: window along the time dimension used for clustering
        filterScale: size of the filter kernel used to smooth the averaged fiducial track
        filterMethod: filter used for smoothing, one of the keys of trackFiducials.FILTER_FUNCS
            (Gaussian, Median or Uniform kernel)
        clumpMinSize: minimum clump size passed on to trackFiducials.extractTrajectoriesClump
        singleFiducial: if True, all fiducial data is treated as coming from one fiducial and track
            fragments are not aligned to each other

    Notes
    -----

    Output is a new pipeline with a fiducial_<dim> column for every dimension returned by the track
    averaging (e.g. fiducial_x, fiducial_y) plus an isFiducial column.

    """
    # imported at class level because the filterMethod choices are needed when the class is defined
    import PYMEcs.Analysis.trackFiducials as tfs
    inputName = Input('filtered')

    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    filterMethod = Enum(tfs.FILTER_FUNCS.keys())
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        import PYMEcs.Analysis.trackFiducials as tfs

        source = namespace[self.inputName]
        result = tabular.mappingFilter(source)

        # A single fiducial needs no alignment; skipping it avoids offsets between fiducial track
        # fragments that incomplete tracks would otherwise introduce.
        do_align = not self.singleFiducial

        t, x, y, z, fiducial_mask = tfs.extractTrajectoriesClump(
            source,
            clumpRadiusVar='error_x',
            clumpRadiusMultiplier=self.radiusMultiplier,
            timeWindow=self.timeWindow,
            clumpMinSize=self.clumpMinSize,
            align=do_align)

        averaged = tfs.AverageTrack(source,
                                    (t, x, y, z),
                                    filter=self.filterMethod,
                                    filterScale=self.filterScale,
                                    align=do_align)

        # one output column per averaged dimension
        for dim in averaged.keys():
            result.addColumn('fiducial_%s' % dim, averaged[dim])
        result.addColumn('isFiducial', fiducial_mask)

        # carry the metadata across when the input has any
        try:
            result.mdh = source.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = result

    @property
    def hide_in_overview(self):
        # NOTE(review): this class declares no 'columns' trait; possibly copied from another module
        return ['columns']
Пример #18
0
class RGBImageUpload(ImageUpload):
    """
    Render an image to RGB (png) and upload it to an OMERO server, optionally
    attaching localization files.

    Parameters
    ----------
    input_image : str
        name of image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabular's
        will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine name of image on OMERO server
    omero_dataset : str
        name of OMERO dataset to add the image to. If the dataset does not
        already exist it will be created.
    omero_project : str
        name of OMERO project to link the dataset to. If the project does not
        already exist it will be created.
    zoom : int
        zoom factor applied when rendering the image
    scaling : str
        how to scale the intensity - one of 'min-max' or 'percentile'
    scaling_factor: float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta, and yellow rather than RGB for multi-channel
        images. True, by default.

    Notes
    -----
    OMERO server address and user login information must be stored in the user
    PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        """Render ``image`` as an RGB picture and write it to ``path``.

        Images with more than one channel are rendered with the CMY converter
        when ``colorblind_friendly`` is set; all other cases use the RGB one.
        """
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        use_cmy = self.colorblind_friendly and image.data.shape[3] != 1
        render = image_to_cmy if use_cmy else image_to_rgb

        rendered = render(image,
                          zoom=self.zoom,
                          scaling=self.scaling,
                          scaling_factor=self.scaling_factor)

        Image.fromarray(rendered, mode='RGB').save(path)
Пример #19
0
class HDBSCANClustering(ModuleBase):
    """
    Performs HDBSCAN clustering on an input tabular datasource.

    Parameters
    ----------

        columns: names of the columns used as clustering coordinates (default ['x', 'y'])
        min_clump_size: minimum cluster size passed to HDBSCAN as ``min_cluster_size``.
            Technically the only required parameter.
        search_radius: if > 0, additionally extract a DBSCAN-style clustering from the HDBSCAN
            single-linkage tree, cut at this distance. Skipped if 0 or None.
        clump_column_name: output column for the HDBSCAN cluster labels
        clump_prob_column_name: output column for the HDBSCAN membership probabilities
        clump_dbscan_column_name: output column for the DBSCAN-style labels (only added when
            search_radius > 0)

    Notes
    -----

    All label columns are shifted up by one so that 0 means "unclumped" (HDBSCAN itself labels
    unclustered points -1 and numbers clusters from 0).

    See https://github.com/scikit-learn-contrib/hdbscan
    Lots of other parameters not mapped.

    """
    input_name = Input('filtered')
    columns = ListStr(['x', 'y'])
    search_radius = Float()
    min_clump_size = Int(100)
    clump_column_name = CStr('hdbscan_id')
    clump_prob_column_name = CStr('hdbscan_prob')
    clump_dbscan_column_name = CStr('dbscan_id')
    output_name = Output('hdbscan_clustered')

    def execute(self, namespace):
        import hdbscan

        inp = namespace[self.input_name]
        mapped = tabular.mappingFilter(inp)

        # one row per point, one column per requested coordinate
        points = np.vstack([inp[k] for k in self.columns]).T

        clusterer = hdbscan.HDBSCAN(min_cluster_size=self.min_clump_size)
        clusterer.fit(points)

        # Note that hdbscan gives unclustered points label of -1, and first value starts at 0.
        # shift hdbscan labels up by one to match existing convention that a clumpID of 0 corresponds to unclumped
        mapped.addColumn(str(self.clump_column_name), clusterer.labels_ + 1)
        mapped.addColumn(str(self.clump_prob_column_name),
                         clusterer.probabilities_)

        if self.search_radius is not None and self.search_radius > 0:
            # Extract dbscan clustering from hdbscan clusterer
            dbscan = clusterer.single_linkage_tree_.get_clusters(
                self.search_radius, self.min_clump_size)

            # shift dbscan labels up by one to match existing convention that a clumpID of 0 corresponds to unclumped
            mapped.addColumn(str(self.clump_dbscan_column_name), dbscan + 1)

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.output_name] = mapped
        print('finished clustering')
class DetectPSF(ModuleBase):
    """
    Detect point-like objects (PSF candidates) with a difference-of-Gaussian
    blob detector.

    The input image is indexed as (X, Y, Z, C). Each channel C is processed
    independently. Its mean projection along Z is rescaled to [0, 1] and
    passed to ``skimage.feature.blob_dog``.

    Outputs
    -------
    output_pos : list of ndarray
        One integer array per channel, with one row ``(x, y, z, sigma)`` per
        detection. ``z`` is the central slice of the stack and ``sigma`` is
        the blob scale returned by ``blob_dog``. All values are truncated to
        integers.

    Parameters
    ----------
    min_sigma, max_sigma : Float
        Blob scale range in the same units as ``voxelsize.x``. They are
        divided by ``voxelsize.x`` to obtain pixels.
    sigma_ratio : Float
        Ratio between successive Gaussian scales used by ``blob_dog``.
    percent_threshold : Float
        Detection threshold as a fraction of the rescaled projection maximum.
    overlap : Float
        ``blob_dog`` overlap parameter.
    exclude_border : Int
        Detections within this many pixels of an X/Y edge are discarded.
    ignore_z : Bool
        Place every detection on the central z-slice. Only True is supported.
    """

    inputName = Input('input')

    min_sigma = Float(1.0)
    max_sigma = Float(3.0)
    sigma_ratio = Float(1.6)
    percent_threshold = Float(0.1)
    overlap = Float(0.5)
    exclude_border = Int(50)
    ignore_z = Bool(True)

    output_pos = Output('psf_pos')

    def execute(self, namespace):
        ims = namespace[self.inputName]

        pixel_size = ims.mdh['voxelsize.x']
        n_channels = ims.data.shape[3]

        pos = list()
        projections = list()
        for c in range(n_channels):
            mean_project = self._mean_projection(ims, c)
            # keep the un-rescaled projection for the diagnostic plot
            projections.append(mean_project.copy())

            # rescale to [0, 1]; a constant projection would otherwise
            # become NaN (0 / 0)
            mean_project -= mean_project.min()
            peak = mean_project.max()
            if peak > 0:
                mean_project /= peak

            # Newer skimage versions accept exclude_border directly; the
            # border filtering is done manually below for compatibility.
            blobs = feature.blob_dog(mean_project,
                                     self.min_sigma / pixel_size,
                                     self.max_sigma / pixel_size,
                                     sigma_ratio=self.sigma_ratio,
                                     overlap=self.overlap,
                                     threshold=self.percent_threshold * mean_project.max())

            # rows are (x, y, sigma); drop detections near the X/Y edges
            edge_mask = (blobs[:, 0] > self.exclude_border) & (blobs[:, 0] < mean_project.shape[0] - self.exclude_border)
            edge_mask &= (blobs[:, 1] > self.exclude_border) & (blobs[:, 1] < mean_project.shape[1] - self.exclude_border)
            blobs = blobs[edge_mask]

            if not self.ignore_z:
                raise NotImplementedError("z centering not yet implemented")
            # insert the central z-slice as the third column -> (x, y, z, sigma)
            blobs = np.insert(blobs, 2, ims.data.shape[2] // 2, axis=1)

            pos.append(blobs.astype(int))
        namespace[self.output_pos] = pos

        self._plot_detections(projections, pos)

    @staticmethod
    def _mean_projection(ims, c):
        """Return the mean projection along Z of channel ``c``."""
        mean_project = ims.data[:, :, :, c].mean(2).squeeze()
        # Pixels sitting at the 16-bit ceiling are replaced with a fixed 200,
        # presumably to suppress saturated/flagged pixels -- confirm.
        mean_project[mean_project == 2**16 - 1] = 200
        return mean_project

    @staticmethod
    def _plot_detections(projections, pos):
        """Best-effort figure: one panel per channel with detections circled."""
        try:
            n_channels = len(projections)
            fig, axes = pyplot.subplots(1, n_channels, figsize=(4 * n_channels, 3), squeeze=False)
            for c, (projection, blobs) in enumerate(zip(projections, pos)):
                ax = axes[0, c]
                ax.imshow(projection)
                ax.set_axis_off()
                for x, y, z, sig in blobs:
                    ax.add_patch(pyplot.Circle((y, x), sig, color='red', linewidth=2, fill=False))
        except Exception as e:
            print(e)
class CropPSF(ModuleBase):
    """
    Crop individual PSF stacks around given positions and flag poor ones.

    For every channel ``c`` of the input (indexed X, Y, Z, C) and every row
    of ``psf_pos[c]`` (first three entries ``x, y, z``, e.g. the output of
    DetectPSF), a box of shape
    ``(2*half_roi_x+1, 2*half_roi_y+1, 2*half_roi_z+1)`` centred on that
    position is cropped. Crops are stacked along the 4th dimension in
    channel-then-position order.

    A crop is rejected when either of these holds:

    * its mean projection along Z, thresholded at ``threshold_reject`` times
      its maximum, does not form exactly one connected region;
    * that region's intensity-weighted centre of mass lies more than
      ``com_reject`` pixels from the crop centre.

    Rejected indices are appended to ``ignore_pos``, so they persist on the
    module. Every in-range index in ``ignore_pos`` is left out of the output.

    Positions must lie at least ``half_roi`` pixels from every edge.
    Otherwise the crop has the wrong shape and the assignment into the result
    fails (presumably with a numpy ValueError).
    """

    inputName = Input('input')
    input_pos = Input('psf_pos')

    ignore_pos = List(Int, [])
    threshold_reject = Float(0.5)
    com_reject = Float(2.0)

    half_roi_x = Int(20)
    half_roi_y = Int(20)
    half_roi_z = Int(60)

    output_images = Output('psf_cropped')

    def execute(self, namespace):
        ims = namespace[self.inputName]
        psf_pos = namespace[self.input_pos]

        roi_shape = (self.half_roi_x * 2 + 1, self.half_roi_y * 2 + 1, self.half_roi_z * 2 + 1)
        n_psfs = sum(ar.shape[0] for ar in psf_pos)
        res = np.zeros(roi_shape + (n_psfs,))

        counter = 0
        for c in range(ims.data.shape[3]):
            for row in psf_pos[c]:
                x, y, z = row[:3]
                x_slice = slice(x - self.half_roi_x, x + self.half_roi_x + 1)
                y_slice = slice(y - self.half_roi_y, y + self.half_roi_y + 1)
                z_slice = slice(z - self.half_roi_z, z + self.half_roi_z + 1)
                res[:, :, :, counter] = ims.data[x_slice, y_slice, z_slice, c].squeeze()

                if self._is_rejected(res[:, :, :, counter]) and counter not in self.ignore_pos:
                    self.ignore_pos.append(counter)

                counter += 1

        print("images ignore: {}".format(self.ignore_pos))

        # ignore_pos persists between runs and may be user-edited; skip
        # entries that do not index this result instead of raising IndexError
        in_range = [i for i in self.ignore_pos if -n_psfs <= i < n_psfs]
        if len(in_range) != len(self.ignore_pos):
            print("out-of-range ignore_pos entries skipped: {}".format(
                sorted(set(self.ignore_pos) - set(in_range))))
        mask = np.ones(n_psfs, dtype=bool)
        mask[in_range] = False

        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["ImageType"] = 'PSF'
            if "PSFExtraction.SourceFilenames" not in new_mdh.keys():
                new_mdh["PSFExtraction.SourceFilenames"] = ims.filename
        except Exception as e:
            print(e)

        namespace[self.output_images] = ImageStack(data=res[:, :, :, mask], mdh=new_mdh)

    def _is_rejected(self, crop):
        """Return True if ``crop`` does not contain a single, centred peak."""
        crop_flatten = crop.mean(2)
        labeled_image, labeled_counts = ndimage.label(crop_flatten > crop_flatten.max() * self.threshold_reject)
        if labeled_counts != 1:
            # several peaks, or none at all (e.g. an all-zero crop, whose
            # centre of mass would be NaN and never exceed com_reject)
            return True
        com = np.asarray(ndimage.center_of_mass(crop_flatten, labeled_image, 1))
        img_center = np.asarray([(s - 1) * 0.5 for s in labeled_image.shape])
        return np.linalg.norm(com - img_center) > self.com_reject
class ShiftImage(CacheCleanupModule):
    """
    Performs FT based image shift. Only shift in 2D.

    Each frame is zero-padded and Fourier transformed. It is then shifted by
    the drift the interpolators report for that frame and cropped back to the
    original size.

    Inputs
    ------
    input_image : ImageStack
        Images with drift.
    input_drift_interpolator :
        Indexable pair of callables. Element 0 returns the x drift and
        element 1 the y drift when called with frame number / time.

    Outputs
    -------
    outputName : ImageStack

    Parameters
    ----------
    padding_multipler : Int
        Padding (as multiple of image size) added on each side of the image
        before shifting to avoid artifacts.
    cache_image : File
        Use file as disk cache if provided. An empty string keeps the result
        in memory.
    """

    input_image = Input('input')
    input_drift_interpolator = Input('drift_interpolator')
    padding_multipler = Int(1)

    cache_image = File("shifted_image.bin")
    outputName = Output('drift_corrected_image')

    def _execute(self, namespace):
        self._start_time = time.time()

        ims = namespace[self.input_image]

        # Time point of each frame. If the metadata records an earlier z/time
        # binning (e.g. by the Binning module), map back to the centre of
        # each bin in original frame units.
        t_out = np.arange(ims.data.shape[2], dtype=float)
        if 'recipe.binning' in ims.mdh.keys():
            bin_t = ims.mdh['recipe.binning'][2]
            t_out *= bin_t
            t_out += 0.5 * bin_t

        interpolators = namespace[self.input_drift_interpolator]
        dx = interpolators[0](t_out)
        dy = interpolators[1](t_out)

        shifted_images = self.shift_images(ims, np.stack([dx, dy], 1), ims.mdh)

        namespace[self.outputName] = ImageStack(shifted_images,
                                                titleStub=self.outputName,
                                                mdh=ims.mdh)

    def shift_images(self, ims, shifts, mdh):
        """
        Shift every frame of ``ims`` by the matching row of ``shifts``.

        Parameters
        ----------
        ims : ImageStack
            Source frames, indexed as ``ims.data[x, y, frame]``.
        shifts : ndarray, shape (n_frames, 2)
            Per-frame (x, y) drift. It is converted to pixels by dividing by
            ``mdh.voxelsize.x``/``.y``. A further factor of 1E3 is removed when
            ``mdh.voxelsize.units == 'um'`` (presumably the drift is in nm).
        mdh : metadata handler
            Supplies ``voxelsize``.

        Returns
        -------
        ndarray or numpy.memmap
            Float array with the same (x, y, frame) shape as the input. It is
            a memmap backed by ``cache_image`` unless that is empty.
        """
        import warnings

        n_x, n_y, n_frames = ims.data.shape[:3]

        # padding[d] = (before, after) for axis d: padding_multipler image
        # widths on each side (2D only)
        padding = np.stack((ims.data.shape[:2], ) * 2, -1)
        padding *= self.padding_multipler
        padded_image_shape = np.asarray(ims.data.shape[:2], dtype=np.int64) + padding.sum(1)

        dtype = ims.data[:, :, 0].dtype
        padded_image = np.zeros(padded_image_shape, dtype=dtype)

        kx, ky = np.meshgrid(np.fft.fftfreq(padded_image_shape[0]),
                             np.fft.fftfreq(padded_image_shape[1]),
                             indexing='ij')

        images_shape = (n_x, n_y, n_frames)
        if self.cache_image == "":
            shifted_images = np.empty(images_shape)
        else:
            shifted_images = np.memmap(self.cache_image,
                                       dtype=float,
                                       mode='w+',
                                       shape=images_shape)

        shifts_in_pixels = np.copy(shifts)
        try:
            shifts_in_pixels[:, 0] = shifts[:, 0] / mdh.voxelsize.x
            shifts_in_pixels[:, 1] = shifts[:, 1] / mdh.voxelsize.y
            if mdh.voxelsize.units == 'um':
                shifts_in_pixels /= 1E3
        except Exception as e:
            # Best effort: continue with whatever conversion succeeded, but
            # actually tell the user (constructing a Warning does not emit it).
            warnings.warn("Failed to convert drift from real distances to pixels: {!r}".format(e))

        x0, y0 = padding[0, 0], padding[1, 0]
        interior = (slice(x0, x0 + n_x), slice(y0, y0 + n_y))
        report_every = max(n_frames // 5, 1)

        for i in range(n_frames):
            padded_image[interior] = ims.data[:, :, i].squeeze()

            ft_image = np.fft.fftn(padded_image)
            data_shifted = shift_image_direct_rough(ft_image,
                                                    shifts_in_pixels[i],
                                                    kxy=(kx, ky))
            shifted_images[:, :, i] = data_shifted[interior]

            if (i + 1) % report_every == 0:
                if isinstance(shifted_images, np.memmap):
                    shifted_images.flush()
                print("{:.2f} s. Completed shifting {} of {} total images.".
                      format(time.time() - self._start_time, i + 1,
                             n_frames))

        # make sure the tail of the cache file is written as well
        if isinstance(shifted_images, np.memmap):
            shifted_images.flush()

        return shifted_images
class AlignPSF(ModuleBase):
    """
    Align PSF stacks by redundant cross correlation.

    The PSFs are the entries along the 4th input dimension. The steps are:

    1. Each PSF is cropped to the ``2*z_crop_half_roi+1`` central z-slices.
    2. It is normalised and, optionally, Tukey-windowed.
    3. It is cross-correlated against every other PSF.
    4. The pairwise peak positions are combined by least squares into one
       drift per PSF. Pairs whose residual exceeds the tolerance are dropped
       while the system stays full rank.
    5. The PSFs are shifted with Fourier phase ramps.

    Outputs
    -------
    output_images : ImageStack
        Shifted PSF stacks: the raw input, or the cleaned/cropped stacks when
        ``debug`` is True.
    output_cross_corr_images : ImageStack
        Normalised, thresholded cross-correlation of every PSF pair.
    output_cross_corr_images_fitted : ImageStack
        Model fitted to each cross-correlation.

    Parameters
    ----------
    normalize_z : Bool
        Normalise intensity per z-plane instead of per stack.
    tukey : Float
        Tukey window shape parameter. 0 disables windowing.
    rcc_tolerance : Float
        Residual threshold for rejecting pairs. It is converted to pixels as
        ``rcc_tolerance * 1E3 / voxelsize.x``.
    z_crop_half_roi : Int
        Half-width (in slices) of the z range used for correlation.
    peak_detect : 'Gaussian' or 'RBF'
        Method used to locate each cross-correlation peak.
    debug : Bool
        Output the shifted cleaned stacks instead of the shifted raw stacks.
    """

    inputName = Input('psf_cropped')
    normalize_z = Bool(True)
    tukey = Float(0.50)
    rcc_tolerance = Float(5.0)
    z_crop_half_roi = Int(15)
    peak_detect = Enum(['Gaussian', 'RBF'])
    debug = Bool(False)
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C' (one PSF per index of the last axis)
        psf_stack = ims.data[:, :, :, :]

        z_center = psf_stack.shape[2] // 2
        z_slice = slice(z_center - self.z_crop_half_roi, z_center + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:, :, z_slice, :])

        if self.tukey > 0:
            tukey_window = self._tukey_window()
            masks = [tukey_window(dim_len, self.tukey) for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:, :, :, None]

        drifts = self.calculate_shifts(cleaned_psf_stack, self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        namespace[self.output_images] = ImageStack(self.shift_images(cleaned_psf_stack if self.debug else psf_stack, drifts), mdh=ims.mdh)

    @staticmethod
    def _tukey_window():
        """Return scipy's Tukey window (``scipy.signal.tukey`` was removed in SciPy 1.13)."""
        try:
            from scipy.signal.windows import tukey
        except ImportError:
            tukey = signal.tukey
        return tukey

    def normalize_images(self, psf_stack):
        """
        Background-subtract and rescale PSF stacks for cross-correlation.

        The steps are:

        1. Negative values are clipped.
        2. Each stack's minimum is subtracted.
        3. The data is scaled so its maximum becomes 1.05. The maximum is
           taken per z-plane if ``normalize_z``, otherwise per stack.
        4. 0.05 is subtracted and negatives are clipped again.
        """
        # in case it is already bg subtracted
        cleaned_psf_stack = np.clip(psf_stack, 0, None)

        # subtract bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0, 1, 2), keepdims=True)

        # per plane if normalize_z, otherwise per psf stack
        axes = (0, 1) if self.normalize_z else (0, 1, 2)
        peak = cleaned_psf_stack.max(axis=axes, keepdims=True)
        # all-zero planes/stacks would otherwise turn into NaN (0 / 0)
        peak = np.where(peak > 0, peak, 1)
        cleaned_psf_stack /= peak / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """
        Estimate the drift of every PSF by redundant cross correlation.

        Parameters
        ----------
        psf_stack : ndarray, indexed (x, y, z, psf)
        drift_tolerance : float
            Residual (in pixels) above which a pair is discarded.

        Returns
        -------
        ndarray, shape (n_psfs, 3)
            Per-PSF shift in pixels, as consumed by ``shift_images``.
        """
        n_steps = psf_stack.shape[3]
        # number of unordered PSF pairs; integer division keeps this an int
        # on Python 3 as well
        coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        output_cross_corr_images = np.zeros(psf_stack.shape[:3] + (coefs_size,), dtype=float)
        output_cross_corr_images_fitted = np.zeros(psf_stack.shape[:3] + (coefs_size,), dtype=float)

        counter = 0
        for i in range(n_steps - 1):
            for j in range(i + 1, n_steps):
                # pair (i, j) constrains the sum of consecutive steps i..j-1
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))
                correlate_result = self._normalized_correlation(psf_stack[:, :, :, i], psf_stack[:, :, :, j])
                output_cross_corr_images[:, :, :, counter] = np.nan_to_num(correlate_result)

                shift, fitted = self._fit_peak(correlate_result)
                shifts[counter, :] = shift
                output_cross_corr_images_fitted[:, :, :, counter] = fitted

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        drifts = self._solve_drifts(coefs, shifts, drift_tolerance)
        drifts = drifts - self._center_offset(psf_stack)

        self._plot_drifts(drifts)

        return drifts

    @staticmethod
    def _normalized_correlation(psf_a, psf_b, threshold=0.5):
        """Cross-correlate, rescale to [0, 1] and keep only the main peak (others NaN)."""
        correlate_result = signal.correlate(psf_a, psf_b, mode="same")
        correlate_result -= correlate_result.min()
        correlate_result /= correlate_result.max()

        correlate_result[correlate_result < threshold] = np.nan

        labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
        # Protect against > 1 peak in the cross correlation result: keep only
        # the region containing the strongest peak so a single-peak model is
        # not fitted to multi-modal data.
        if labeled_counts > 1:
            region_max = ndimage.maximum(correlate_result, labeled_image, np.arange(labeled_counts) + 1)
            correlate_result[labeled_image != np.argmax(region_max) + 1] = np.nan

        return correlate_result

    def _fit_peak(self, correlate_result):
        """Locate the peak of ``correlate_result``; return (shift, fitted image)."""
        # pixel coordinates along each axis, centred on zero
        dims = list()
        for dim in correlate_result.shape:
            coords = np.arange(dim)
            dims.append(coords - coords.mean())

        if self.peak_detect == "Gaussian":
            res = optimize.least_squares(guassian_nd_error,
                                         [1, 0, 0, 5., 0, 5., 0, 30.],
                                         args=(dims, correlate_result))
            fitted = gaussian_nd(res.x, dims)
            # fitted centre along x, y, z
            shift = res.x[[2, 4, 6]]
        elif self.peak_detect == "RBF":
            rbf_interpolator = build_rbf(dims, correlate_result)
            # NOTE(review): dims are centred on 0 but this starting point is
            # shape/2 -- confirm against rbf_nd_error's coordinate convention.
            start = [correlate_result.shape[0] * 0.5, correlate_result.shape[1] * 0.5, correlate_result.shape[2] * 0.5]
            res = optimize.minimize(rbf_nd_error, start, args=rbf_interpolator)
            fitted = rbf_nd(rbf_interpolator, dims)
            shift = res.x
        else:
            raise Exception("peak founding method not recognised")

        return shift, fitted

    @staticmethod
    def _solve_drifts(coefs, shifts, shift_max):
        """Least-squares per-step drifts with outlier-pair rejection, accumulated per PSF."""
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals_dist = np.linalg.norm(np.matmul(coefs, drifts) - shifts, axis=1)

        # pairs whose residual exceeds shift_max, largest residual first
        worst_first = np.argsort(-residuals_dist)
        worst_first = worst_first[residuals_dist[worst_first] > shift_max]

        # Remove coefs rows from largest residual to smallest, but only while
        # the matrix remains full rank.
        coefs_temp = np.empty_like(coefs)
        removed = 0
        for index in worst_first:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                removed += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(removed))

        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        # first PSF has zero drift; accumulate the steps into per-PSF drifts
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)
        return np.cumsum(drifts, axis=0)

    @staticmethod
    def _center_offset(psf_stack):
        """Offset of the bright core of the mean PSF from the volume centre."""
        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0, 1, 2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        # NOTE(review): the centre is taken as shape/2, i.e. half a pixel past
        # the middle pixel for odd sizes -- confirm this is intended.
        return ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape) * 0.5

    @staticmethod
    def _plot_drifts(drifts):
        """Best-effort scatter plots of x-y and x-z drifts."""
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6, 3))
            for ax, (col, label) in zip(axes, [(1, 'y'), (2, 'z')]):
                ax.scatter(drifts[:, 0], drifts[:, col], s=50)
                ax.set_xlabel('x')
                ax.set_ylabel(label)
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')
            fig.tight_layout()
        except Exception as e:
            print(e)

    def shift_images(self, psf_stack, shifts):
        """
        Shift each PSF stack by ``shifts[i]`` (pixels) using a Fourier phase ramp.

        Returns the magnitude of the inverse FFT, with the same shape and
        dtype as ``psf_stack``.
        """
        kx, ky, kz = np.meshgrid(*[np.fft.fftfreq(n) for n in psf_stack.shape[:3]], indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in range(psf_stack.shape[3]):
            shift = shifts[i]
            phase_ramp = np.exp(-2j * np.pi * (kx * shift[0] + ky * shift[1] + kz * shift[2]))
            ft_image = np.fft.fftn(psf_stack[:, :, :, i])
            shifted_images[:, :, :, i] = np.abs(np.fft.ifftn(ft_image * phase_ramp))

        return shifted_images
class Binning(CacheCleanupModule):
    """
    Downsample 3D data (mean).

    X, Y pixels that don't fill a full bin are dropped. Pixels in the 3rd
    dimension can have a partially filled bin.

    Inputs
    ------
    inputName : ImageStack

    Outputs
    -------
    outputName : ImageStack
        Binned image. ``voxelsize.x``/``voxelsize.y`` are multiplied by the
        x/y bin sizes (``voxelsize.z`` is not rescaled). ``recipe.binning``
        records the cumulative binning factors.

    Parameters
    ----------
    x_start : Int
        Starting index in x.
    x_end : Int
        Stopping index in x. A positive value is exclusive, as in a Python
        slice. A negative value is taken relative to N+1, so the default -1
        keeps the whole axis and -2 drops the last pixel.
    y_start : Int
        Starting index in y.
    y_end : Int
        Stopping index in y (same convention as ``x_end``).
    binsize : List of 3 ints
        Bin size in x, y and z.
    cache_bin : File
        Use file as disk cache if provided. An empty string keeps the result
        in memory.
    """

    inputName = Input('input')
    x_start = Int(0)
    x_end = Int(-1)
    y_start = Int(0)
    y_end = Int(-1)
    binsize = List([1, 1, 1], minlen=3, maxlen=3)
    cache_bin = File("binning_cache_2.bin")
    outputName = Output('binned_image')

    def _execute(self, namespace):
        self._start_time = time.time()
        ims = namespace[self.inputName]
        n_frames = ims.data.shape[2]

        binsize = np.asarray(self.binsize, dtype=int)

        # Kept pixel indices in x / y. Slicing arange(N + 1) instead of
        # arange(N) makes a negative end inclusive (-1 keeps the last pixel).
        x_idx = np.arange(ims.data.shape[0] + 1)[self.x_start:self.x_end]
        y_idx = np.arange(ims.data.shape[1] + 1)[self.y_start:self.y_end]
        # drop trailing pixels that do not fill a whole bin
        x_idx = x_idx[:x_idx.shape[0] // binsize[0] * binsize[0]]
        y_idx = y_idx[:y_idx.shape[0] // binsize[1] * binsize[1]]
        if len(x_idx) == 0 or len(y_idx) == 0:
            raise ValueError("Selected x/y range does not fill a single bin (binsize {})".format(list(binsize)))

        # the last z bin may be only partially filled (ceiling division)
        bincounts = np.asarray([len(x_idx) // binsize[0],
                                len(y_idx) // binsize[1],
                                -(-n_frames // binsize[2])],
                               dtype=np.int64)

        x_slice_ind = slice(x_idx[0], x_idx[-1] + 1)
        y_slice_ind = slice(y_idx[0], y_idx[-1] + 1)

        # reshape target (nx, bx, ny, by, nz, bz): averaging over axes 1, 3
        # and 5 gives the bin means
        new_shape = np.stack([bincounts, binsize], -1).flatten()

        # need to wrap this to work for multiple color channel images
        dtype = ims.data[:, :, 0].dtype

        out_shape = tuple(int(n) for n in bincounts)
        if self.cache_bin == "":
            binned_image = np.empty(out_shape, dtype=dtype)
        else:
            binned_image = np.memmap(self.cache_bin,
                                     dtype=dtype,
                                     mode='w+',
                                     shape=out_shape)
        is_memmap = isinstance(binned_image, np.memmap)

        # one output plane per iteration: a single z bin whose length is
        # inferred, so a partially filled last bin still reshapes
        new_shape_one_chunk = new_shape.copy()
        new_shape_one_chunk[4] = 1
        new_shape_one_chunk[5] = -1

        progress = 0.2 * n_frames
        for i, f in enumerate(np.arange(0, n_frames, binsize[2])):
            raw_data_chunk = ims.data[x_slice_ind, y_slice_ind,
                                      f:f + binsize[2]].squeeze()

            binned_image[:, :, i] = raw_data_chunk.reshape(
                new_shape_one_chunk).mean((1, 3, 5)).squeeze()

            if f + binsize[2] >= progress:
                if is_memmap:
                    binned_image.flush()
                progress += 0.2 * n_frames
                print("{:.2f} s. Completed binning {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + binsize[2], n_frames),
                             n_frames))

        if is_memmap:
            binned_image.flush()

        im = ImageStack(binned_image, titleStub=self.outputName)
        im.mdh.copyEntriesFrom(ims.mdh)
        im.mdh['Parent'] = ims.filename
        try:
            # Metadata must be logged correctly for the measured drift to be
            # applicable to the source image
            im.mdh['voxelsize.x'] *= binsize[0]
            im.mdh['voxelsize.y'] *= binsize[1]
            # accumulate factors across successive binning steps
            if 'recipe.binning' in im.mdh.keys():
                im.mdh['recipe.binning'] = binsize * im.mdh['recipe.binning']
            else:
                im.mdh['recipe.binning'] = binsize
        except Exception:
            # best effort: metadata update failures are ignored
            pass

        namespace[self.outputName] = im
Пример #25
0
class LabelRange(Filter):
    """Label contiguous mask regions, keeping only those within size and site-count limits.

    A pixel is part of the mask when its value exceeds 0.5. A connected
    region survives only if both of these hold:

    * its pixel count lies in ``[minRegionPixels, maxRegionPixels]``;
    * the number of sites inside it lies in ``[minSites, maxSites]``.

    Surviving regions are relabelled with consecutive integers.

    Sites come from a second input, ``inputSitesLabeled``, which must have the
    same shape as the main input. With ``sitesAsMaxima`` the site count is the
    sum of the site image over the region. Otherwise it is the number of
    distinct positive labels found in the region.
    """
    inputSitesLabeled = Input(
        "sites")  # sites and the main input must have the same shape!
    minRegionPixels = Int(10)
    maxRegionPixels = Int(100)
    minSites = Int(4)
    maxSites = Int(6)
    sitesAsMaxima = Bool(False)

    def filter(self, image, imagesites):
        n_channels = image.data.shape[3]

        def as_float(source, z_index, channel):
            # one z-plane (or the full z range for slice(None)) of one channel
            return source.data[:, :, z_index, channel].squeeze().astype('f')

        filt_ims = []
        for channel in range(n_channels):
            if self.processFramesIndividually:
                planes = []
                for z in range(image.data.shape[2]):
                    labelled = self.applyFilter(as_float(image, z, channel),
                                                as_float(imagesites, z, channel),
                                                channel, z, image)
                    planes.append(np.atleast_3d(labelled))
                filt_ims.append(np.concatenate(planes, 2))
            else:
                labelled = self.applyFilter(as_float(image, slice(None), channel),
                                            as_float(imagesites, slice(None), channel),
                                            channel, 0, image)
                filt_ims.append(np.atleast_3d(labelled))

        result = ImageStack(filt_ims, titleStub=self.outputName)
        result.mdh.copyEntriesFrom(image.mdh)
        result.mdh['Parent'] = image.filename

        self.completeMetadata(result)

        return result

    def execute(self, namespace):
        image = namespace[self.inputName]
        sites = namespace[self.inputSitesLabeled]
        namespace[self.outputName] = self.filter(image, sites)

    def applyFilter(self, data, sites, chanNum, frNum, im):
        region_labels, _ = ndimage.label(data > 0.5)
        keep = np.zeros(region_labels.shape, dtype=bool)

        for label_id, bbox in enumerate(ndimage.find_objects(region_labels), start=1):
            region = region_labels[bbox] == label_id
            area = region.sum()
            if not (self.minRegionPixels <= area <= self.maxRegionPixels):
                continue

            site_values = sites[bbox][region]
            if self.sitesAsMaxima:
                n_sites = site_values.sum()
            else:
                # count distinct labels, excluding background label 0
                n_sites = (np.unique(site_values) > 0).sum()

            if self.minSites <= n_sites <= self.maxSites:
                keep[bbox] |= region

        relabelled, _ = ndimage.label(keep)
        return relabelled

    def completeMetadata(self, im):
        limits = [('Labelling.MinSize', self.minRegionPixels),
                  ('Labelling.MaxSize', self.maxRegionPixels),
                  ('Labelling.MinSites', self.minSites),
                  ('Labelling.MaxSites', self.maxSites)]
        for key, value in limits:
            im.mdh[key] = value
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.
    
    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    
    Runtime will vary hugely depending on size of dataset and settings.
    
    ``cache_fft`` is necessary for large datasets.
        
    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.*   Dataset to correct.
    
    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results as ``(t_shift, shifts)``: window centre times and the
        cumulative drift at each of those times.
    output_drift_plot : Plot
        *Deprecated.*   Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.*   Drift-corrected dataset.
    
    Parameters
    ----------
    step : Int
        Setting for image construction. Step size between images.
    window : Int
        Setting for image construction. Number of frames used per image. Should be equal to or larger than step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Float
        Size of correlation window. Frames are only compared if within this frame range. N/A for DCC.
    multiprocessing : Bool
        Enables multiprocessing.
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """
    
    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')
    
    def calc_corr_drift_from_locs(self, x, y, z, t):
        """Bin localizations into time windows and correlate their Fourier transforms.

        Windows of ``self.window`` frames start every ``self.step`` frames. The
        localizations in each window are turned into a Fourier-transformed image
        by ``calc_fft_from_locs``. The results are stored in memory, or in a
        memmap at ``self.cache_fft`` when that is set, and passed to
        ``self.calc_corr_drift_from_ft_images``.

        Parameters
        ----------
        x, y, z, t : ndarray
            Localization coordinates and frame numbers, all the same length.

        Returns
        -------
        time_values_mid : ndarray
            Centre time of each window.
        shifts : ndarray
            Shifts from ``calc_corr_drift_from_ft_images``, multiplied by
            ``binsize``, with columns restored to x, y, z order.
        coefs
            Coefficients as returned by ``calc_corr_drift_from_ft_images``.
        """
        # bin edges for histogram
        bx = np.arange(x.min(), x.max() + self.binsize + 1, self.binsize)
        by = np.arange(y.min(), y.max() + self.binsize + 1, self.binsize)
        bz = np.arange(z.min(), z.max() + self.binsize + 1, self.binsize)
        
        # pad bin length to odd number so image size is even
        if bx.shape[0] % 2 == 0:
            bx = np.concatenate([bx, [bx[-1] + bx[1] - bx[0]]])
        if by.shape[0] % 2 == 0:
            by = np.concatenate([by, [by[-1] + by[1] - by[0]]])
        if bz.shape[0] > 2 and bz.shape[0] % 2 == 0:
            bz = np.concatenate([bz, [bz[-1] + bz[1] - bz[0]]])
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Ops. Image not correctly padded to even size."

        # start time of all windows, allow partial window near end of pipeline
        time_values = np.arange(t.min(), t.max() + 1, self.step)
        # 2d array, start and end time of windows
        time_values = np.stack([time_values, np.clip(time_values + self.window, None, t.max())], axis=1)
        n_steps = time_values.shape[0]
        # centre time of each window for returning; last window may have different spacing
        time_values_mid = time_values.mean(axis=1)

        if np.any(np.diff(t) < 0):  # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]
            
        # [start, end) indexes into the sorted arrays for each window
        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1] - 1, side='right')

        # Swap the longest axis to the last position to speed up the rfft.
        # Only this method sees the swapped order; it is undone on return.
        xyz = np.asarray([x, y, z])
        # The bin-edge arrays usually differ in length. Build an explicit 1-D
        # object array, because ragged np.asarray(...) raises in NumPy >= 1.24.
        bxyz = np.empty(3, dtype=object)
        for axis, edges in enumerate((bx, by, bz)):
            bxyz[axis] = edges
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = bxyz[dims_order]
        dims_length = dims_length[dims_order]
        
        # The last axis has (n // 2 + 1) bins, matching a real FFT over that axis.
        # np.complex128 replaces the np.complex alias, which was removed in NumPy 1.24.
        ft_shape = (n_steps, dims_length[0] - 1, dims_length[1] - 1, (dims_length[2] - 1) // 2 + 1)
        # use memmap for caching if cache_fft is defined
        if self.cache_fft == "":
            ft_images = np.zeros(ft_shape, dtype=np.complex128)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=np.complex128, mode='w+', shape=ft_shape)
        
        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))
        
        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))
        
        # Report progress about every 20%. The max() guards against
        # n_steps < 5, where n_steps // 5 == 0 would raise ZeroDivisionError.
        report_every = max(n_steps // 5, 1)

        # fill ft_images
        # if multiprocessing, can either use or not caching
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:, slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size) for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                # with a cache file the workers write to the memmap themselves
                if self.cache_fft == "":
                    ft_images[j] = res

                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i+1, n_steps))

        else:
            # For each window we wish to correlate, generate an image and store its ft
            for i, ti in enumerate(time_indexes):
                t_slice = slice(*ti)
                ft_images[i] = calc_fft_from_locs(xyz[:, t_slice].T, bxyz, filter_size=self.tukey_size)
                
                if (i + 1) % report_every == 0:
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i+1, n_steps))
        
        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))
        
        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)
        
        # clean up of ft_images, potentially really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images
        
        # dims_order is a single swap, so indexing with it again restores x, y, z order
        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        """Compute drift from ``input_for_correction`` and apply it to ``input_for_mapping``.

        Populates ``outputName``, ``output_drift``, ``output_drift_plot`` and
        ``output_cross_cor`` in ``namespace``.
        """
        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")
        
        locs = namespace[self.input_for_correction]

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=process_count)})

        # always release the worker pool, even if the drift calculation fails
        try:
            drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'], locs['z'] * (0 if self.flatten_z else 1), locs['t'])
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()
            
        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        # must be checked before the new columns are added
        had_drift_columns = 'dx' in out.keys()
        out.addColumn('dx', dx)
        out.addColumn('dy', dy)
        out.addColumn('dz', dz)
        if had_drift_columns:
            # Workaround for a mappingFilter oddity: addColumn adds a new column
            # but keeps the old one. __getitem__ returns the new column, while
            # mappings use the old one. Wrapping in another mappingFilter makes
            # the new column the 'old column'.
            out = tabular.mappingFilter(out)
        out.setMapping('x', 'x + dx')
        out.setMapping('y', 'y + dy')
        out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts
        
        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))
        
        namespace[self.output_cross_cor] = self._cc_image
Пример #27
0
class CalculateFRCFromImages(CalculateFRCBase):
    """
    Takes a pair of images and calculates the Fourier shell/ring correlation (FSC / FRC).
    
    Inputs
    ------
    input_image_a : ImageStack
        First of two images.
        
    Outputs
    -------
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.
    
    Parameters
    ----------
    image_b_path : File
        File path of the second of the two images. If blank, the first image is compared against itself.
    c_channel : int
        Color channel of the images to use.
    image_a_z : int
        Ignored unless flatten_z is True. In that case, either selects the z plane to use (>=0) or performs a maximum projection (<0) for the first image.
    image_b_z : int
        Ignored unless flatten_z is True. In that case, either selects the z plane to use (>=0) or performs a maximum projection (<0) for the second image.
    flatten_z : Bool
        If enabled, ignores z information and performs a 2D FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    
    """
    
    input_image_a = Input('input')
    image_b_path = File(info_text="Filepath of image to compare against. Leave blank to compare against currently opened image.")
    c_channel = Int(0)
    flatten_z = Bool(True)
    image_a_z = Int(-1)
    image_b_z = Int(-1)
    
    def execute(self, namespace):
        """Load both images, optionally flatten z, and compute the FRC/FSC.

        Writes ``output_frc_dict`` and ``output_frc_raw`` into ``namespace``,
        then calls ``save_to_file``.
        """
        import multiprocessing

        self._namespace = namespace

        if self.multiprocessing:
            # Use up to 2 workers, never fewer than 1. cpu_count() - 1 is 0 on a
            # single-core machine, and Pool(processes=0) raises ValueError.
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        def _flatten(data, z_index):
            """Pick plane ``z_index`` (>= 0) or max-project over z (< 0)."""
            # .squeeze() already removed z for single-plane stacks; nothing to flatten
            if data.ndim < 3:
                return data
            if z_index >= 0:
                return data[:, :, z_index]
            return data.max(2)

        # always release the worker pool, even if the calculation fails
        try:
            image_a = namespace[self.input_image_a]

            if self.image_b_path.strip():
                image_b = ImageStack(filename=self.image_b_path)
            else:
                image_b = image_a

            # np.float was removed in NumPy 1.24; builtin float is the same dtype (float64)
            self._pixel_size_in_nm = np.zeros(3, dtype=float)
            self._pixel_size_in_nm[0] = image_a.mdh.voxelsize.x
            self._pixel_size_in_nm[1] = image_a.mdh.voxelsize.y
            try:
                self._pixel_size_in_nm[2] = image_a.mdh.voxelsize.z
            except Exception:
                # z voxel size is optional (e.g. 2D data); leave it at 0
                pass
            if image_a.mdh.voxelsize.units == 'um':
                self._pixel_size_in_nm *= 1.E3

            image_a_data = image_a.data[:, :, :, self.c_channel].squeeze()
            image_b_data = image_b.data[:, :, :, self.c_channel].squeeze()
            if self.flatten_z:
                print("2D mode. Slice if z index >= 0 otherwise max projection")
                image_a_data = _flatten(image_a_data, self.image_a_z)
                image_b_data = _flatten(image_b_data, self.image_b_z)

            image_pair = self.preprocess_images([image_a_data, image_b_data])

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, None)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)