示例#1
0
class PSFSettings(HasTraits):
    """Editable bundle of PSF parameters, with a TraitsUI view for editing them."""
    wavelength_nm = Float(700.)
    NA = Float(1.47)
    vectorial = Bool(False)
    zernike_modes = Dict()
    zernike_modes_lower = Dict()
    phases = List([0, .5, 1, 1.5])
    four_pi = Bool(False)

    def default_traits_view(self):
        """Return the default view.

        ``zernike_modes_lower`` and ``phases`` are only displayed while
        ``four_pi`` is set; ``phases`` is labelled 'phases/pi' in the dialog.
        """
        from traitsui.api import View, Item

        # visibility condition shared by the 4Pi-only entries
        when_4pi = 'four_pi==True'

        items = [
            Item(name='wavelength_nm'),
            Item(name='NA'),
            Item(name='vectorial'),
            Item(name='four_pi', label='4Pi'),
            Item(name='zernike_modes'),
            Item(name='zernike_modes_lower', visible_when=when_4pi),
            Item(name='phases', visible_when=when_4pi, label='phases/pi'),
        ]

        return View(*items, resizable=True, buttons=['OK'])
示例#2
0
class HistByID(ModuleBase):
    """Plot histogram of a column by ID

    The ``histkey`` column of the input is reduced with ``uniqueByID``
    (defined elsewhere) using the ``IDkey`` column, and the resulting values
    are drawn as a matplotlib histogram with ``nbins`` bins.

    ``minval`` / ``maxval`` set the histogram range; leaving either as NaN
    (the default) uses the corresponding extreme of the data instead.
    """
    inputName = Input('measurements')
    IDkey = CStr('objectID')
    histkey = CStr('qIndex')
    outputName = Output('outGraph')
    nbins = Int(50)
    minval = Float(float('nan'))
    maxval = Float(float('nan'))

    def execute(self, namespace):
        """Draw the histogram in a new matplotlib figure.

        NOTE(review): nothing is stored under ``namespace[self.outputName]``
        here - confirm whether the declared output is meant to be populated.
        """
        import math
        import matplotlib.pyplot as plt

        meas = namespace[self.inputName]
        ids = meas[self.IDkey]

        # only the reduced values are needed, the unique IDs are discarded
        _, valsu = uniqueByID(ids, meas[self.histkey])

        # NaN limits act as "auto" and fall back to the data range
        minv = valsu.min() if math.isnan(self.minval) else self.minval
        maxv = valsu.max() if math.isnan(self.maxval) else self.maxval

        plt.figure()
        plt.hist(valsu, self.nbins, range=(minv, maxv))
        plt.xlabel(self.histkey)

    @property
    def _key_choices(self):
        """Sorted column names of the current input, or ``[]`` if unavailable."""
        # best effort: there may be no parent recipe, or the input may not
        # exist yet. Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        """TraitsUI view with combo boxes for the input and the column keys."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName',
                         editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('histkey',
                         editor=CBEditor(choices=self._key_choices)),
                    Item('nbins'),
                    Item('minval'),
                    Item('maxval'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
示例#3
0
class WormlikeSource(PointSource):
    """Point source whose points come from a simulated worm-like chain."""
    kbp = Float(200)
    steplength = Float(1.0)
    lengthPerKbp = Float(10.0)
    persistLength = Float(150.0)

    def getPoints(self):
        """Simulate a chain with the current settings and return ``(xp, yp, zp)``."""
        from PYME.simulation import wormlike2

        chain = wormlike2.wormlikeChain(
            self.kbp,
            self.steplength,
            self.lengthPerKbp,
            self.persistLength,
        )
        return chain.xp, chain.yp, chain.zp
示例#4
0
class Scale(Filter):
    """Multiply the image intensities by the constant factor ``scale``."""

    scale = Float(1)

    def applyFilter(self, data, chanNum, i, image0):
        """Return ``data`` multiplied by ``scale``; the other arguments are unused."""
        factor = self.scale
        return factor * data
示例#5
0
class NormalizeMean(Filter):
    """Subtract ``offset`` from an image, then rescale so that its mean is 1."""

    offset = Float(0)

    def applyFilter(self, data, chanNum, i, image0):
        """Return ``(data - offset)`` divided by its own mean."""
        shifted = data - self.offset
        mean_value = float(shifted.mean())
        return shifted / mean_value
示例#6
0
class QindexScale(ModuleBase):
    """Add a calibrated copy of a qIndex column.

    The output wraps the input in a ``tabular.mappingFilter`` and gains one
    extra column named ``'<qIndexkey>Cal'``. Entries of the source column
    which are > 0 are multiplied by ``NEquivalent / qindexValue``; all other
    entries are carried over unchanged.
    """
    inputName = Input('qindex')
    outputName = Output('qindex-calibrated')
    qIndexkey = CStr('qIndex')
    qindexValue = Float(1.0)
    NEquivalent = Float(1.0)

    def execute(self, namespace):
        """Compute the calibrated column and store the result under ``outputName``."""
        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)
        qkey = self.qIndexkey

        # Work on a copy: assigning through the mask below would otherwise
        # overwrite the input column whenever the source hands back a
        # reference to (or view of) its own storage.
        scaled = inp[qkey].copy()
        qigood = scaled > 0
        scaled[qigood] = scaled[qigood] * self.NEquivalent / self.qindexValue

        self.newKey = '%sCal' % qkey
        mapped.addColumn(self.newKey, scaled)
        namespace[self.outputName] = mapped

    @property
    def _key_choices(self):
        """Sorted column names of the current input, or ``[]`` if unavailable."""
        # best effort: there may be no parent recipe, or the input may not
        # exist yet. Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        """TraitsUI view with combo boxes for the input and the qIndex column."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName',
                         editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('qIndexkey',
                         editor=CBEditor(choices=self._key_choices)),
                    Item('qindexValue'),
                    Item('NEquivalent'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
示例#7
0
class RuleTile(HasTraits):
    """Timeout settings for a rule, with an auto-generated editing view."""
    task_timeout = Float(60 * 10)
    rule_timeout = Float(60 * 10)

    def get_params(self):
        """Return the names of the class-level editable traits."""
        return self.class_editable_traits()

    @property
    def default_view(self):
        """A view with one item per editable trait, or ``None`` without a wx app."""
        if wx.GetApp() is not None:
            from traitsui.api import View, Item

            items = [Item(trait_name) for trait_name in self.get_params()]
            return View(items, buttons=['OK'])

        return None

    def default_traits_view(self):
        """Stock traits hook for the default view; delegates to ``default_view``."""
        return self.default_view
示例#8
0
class WormlikeSource(PointSource):
    """Point source whose points come from a simulated worm-like chain."""
    kbp = Float(200)
    steplength = Float(1.0)
    lengthPerKbp = Float(10.0)
    persistLength = Float(150.0)

    def getPoints(self):
        """Simulate a chain with the current settings and return ``(xp, yp, zp)``."""
        from PYME.simulation import wormlike2

        chain = wormlike2.wormlikeChain(
            self.kbp,
            self.steplength,
            self.lengthPerKbp,
            self.persistLength,
        )
        return chain.xp, chain.yp, chain.zp

    def genMetaData(self, mdh):
        """Write the source type and the chain parameters into ``mdh``."""
        prefix = 'GeneratedPoints.Source.'

        mdh[prefix + 'Type'] = 'Wormlike'
        mdh[prefix + 'Kbp'] = self.kbp
        mdh[prefix + 'StepLength'] = self.steplength
        mdh[prefix + 'LengthPerKbp'] = self.lengthPerKbp
        mdh[prefix + 'PersistLength'] = self.persistLength
示例#9
0
class LabelByRegionProperty(Filter):
    """Assigns a region property to each contiguous region in the input mask.

    Every pixel of a region is set to that region's value of the selected
    property. When ``filterByProperty`` is set, regions whose value lies
    outside ``[propertyMin, propertyMax]`` are left at zero.
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def applyFilter(self, data, chanNum, frNum, im):
        """Return a float image holding the property value of each labelled region."""
        foreground = data > 0.5
        labels, _ = ndimage.label(foreground)
        regions = skimage.measure.regionprops(labels, None, cache=True)

        output = np.zeros_like(foreground, dtype='float')
        bounding_slices = ndimage.find_objects(labels)

        for region in regions:
            # label numbers start at 1, find_objects is indexed from 0
            sl = bounding_slices[region.label - 1]
            in_region = labels[sl] == region.label

            if self.regionProperty == 'area':
                propValue = region.area
            elif self.regionProperty == 'aspectratio':
                propValue = region.major_axis_length / region.minor_axis_length
            elif self.regionProperty == 'circularity':
                # 4*pi*A / P**2
                propValue = 4 * math.pi * region.area / (region.perimeter *
                                                         region.perimeter)

            # drop regions which fail the (optional) range test
            if self.filterByProperty and not (
                    self.propertyMin <= propValue <= self.propertyMax):
                continue

            output[sl] += in_region * propValue

        return output

    def completeMetadata(self, im):
        """Record which property was used for the labelling."""
        im.mdh['Labelling.Property'] = self.regionProperty
示例#10
0
class FlexiThreshold(Filter):
    """Threshold an image using one of several selectable methods.

    ``simple`` uses ``parameter`` directly as the threshold and ``fractional``
    uses `fractionalThreshold`. Any other choice is looked up as
    ``skf.threshold_<method>``. Except for ``simple``, the data is clipped
    from above at ``clipAt`` before the threshold is estimated.
    """
    method = Enum(
        'simple', 'fractional', 'otsu', 'isodata', 'li',
        'yen')  # newer skimage has minimum, mean and triangle as well
    parameter = Float(0.5)
    clipAt = Float(
        2e6
    )  # used to be 10 - increase to large value for newer PYME renderings

    def fractionalThreshold(self, data):
        """Return the histogram bin edge closest to the requested intensity fraction.

        A 5000-bin histogram is weighted by the lower edge of each bin and
        accumulated; the returned edge is the one where that cumulative sum
        is nearest to ``(1 - parameter)`` times its total.
        """
        counts, edges = np.histogram(data, bins=5000)
        # lower bin edges are used as the representative intensity of each bin
        lower_edges = edges[:-1]
        cumulative = np.cumsum(counts * lower_edges)
        target = cumulative[-1] * (1 - self.parameter)
        nearest = np.argmin(abs(cumulative - target))
        return edges[nearest]

    def applyFilter(self, data, chanNum, frNum, im):
        """Return the boolean mask ``data > threshold`` for the chosen method."""
        if self.method == 'simple':
            return data > self.parameter

        if self.method == 'fractional':
            estimate = self.fractionalThreshold
        else:
            estimate = getattr(skf, 'threshold_%s' % self.method)

        # the mask itself is taken on the unclipped data
        return data > estimate(np.clip(data, None, self.clipAt))

    def completeMetadata(self, im):
        """Record the threshold parameter and method in the image metadata."""
        im.mdh['Processing.ThresholdParameter'] = self.parameter
        im.mdh['Processing.ThresholdMethod'] = self.method
示例#11
0
class InterpolateDrift(ModuleBase):
    """
    Creates a spline interpolator from drift data. (``scipy.interpolate.UnivariateSpline``)

    Inputs
    ------
    input_drift_raw : Tuple of arrays
        ``(tIndex, drift)`` as measured from a localization or image dataset.

    Outputs
    -------
    output_drift_interpolator :
        Drift interpolator. Returns drift when called with frame number / time.
    output_drift_plot : Plot
        Plot of the original and interpolated drift.

    Parameters
    ----------
    degree_of_spline : int
        Degree of the smoothing spline (1 for linear, 3 for cubic).
    smoothing_factor : float
        Smoothing factor. 0 for no smoothing, negative for the
        UnivariateSpline default.
    """

    degree_of_spline = Int(3)  # 1 for linear, 3 for cubic
    smoothing_factor = Float(
        -1)  # 0 for no smoothing. set to negative for UnivariateSpline default
    input_drift_raw = Input('drift_raw')
    output_drift_interpolator = Output('drift_interpolator')
    output_drift_plot = Output('drift_plot')

    def execute(self, namespace):
        """Fit the interpolator, store it, then store and draw the drift plot."""
        frames, raw_drift = namespace[self.input_drift_raw]

        interpolator = interpolate_drift(
            frames,
            raw_drift,
            self.degree_of_spline,
            self.smoothing_factor,
        )
        namespace[self.output_drift_interpolator] = interpolator

        # not needed for the correction itself, only to visualise the drift
        plot_function = partial(generate_drift_plot, frames, raw_drift,
                                interpolator)
        namespace[self.output_drift_plot] = Plot(plot_function)
        namespace[self.output_drift_plot].plot()
示例#12
0
class IsClose(ArithmaticFilter):
    """
    Element-wise closeness test of two images (wrapper for numpy.isclose)

    Parameters:
    -----------
    abs_tolerance: Float
        absolute tolerance (passed as ``atol``)
    rel_tolerance: Float
        relative tolerance (passed as ``rtol``)

    Notes:
    ------
    from numpy docs, tolerances are combined as: absolute(a - b) <= (atol + rtol * absolute(b))

    """
    abs_tolerance = Float(1e-8)
    rel_tolerance = Float(1e-5)

    def applyFilter(self, data0, data1, chanNum, i, image0):
        """Return ``np.isclose(data0, data1)`` using the configured tolerances."""
        tolerances = dict(atol=self.abs_tolerance, rtol=self.rel_tolerance)
        return np.isclose(data0, data1, **tolerances)
示例#13
0
class TimedSpecies(ModuleBase):
    """Assign events to up to three named species according to their ``t`` value.

    For each species with a non-empty name, a column ``p_<name>`` is added
    which is true where ``start <= t < stop``. A ``ColourNorm`` column of ones
    is added as well.
    """
    inputName = Input('filtered')
    outputName = Output('timedSpecies')
    Species_1_Name = CStr('Species1')
    Species_1_Start = Float(0)
    Species_1_Stop = Float(1e6)

    Species_2_Name = CStr('')
    Species_2_Start = Float(0)
    Species_2_Stop = Float(0)

    Species_3_Name = CStr('')
    Species_3_Start = Float(0)
    Species_3_Stop = Float(0)

    def execute(self, namespace):
        """Add the species columns and store the mapped source under ``outputName``."""
        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)
        timedSpecies = self.populateTimedSpecies()

        mapped.addColumn('ColourNorm', np.ones_like(mapped['t'], 'float'))

        for entry in timedSpecies:
            t = mapped['t']
            in_window = (t >= entry['t_start']) * (t < entry['t_end'])
            mapped.addColumn('p_%s' % entry['name'], in_window)

        if 'mdh' in dir(inp):
            # NOTE(review): the metadata handler is assigned by reference, so
            # the 'TimedSpecies' entry presumably also appears in inp.mdh.
            mapped.mdh = inp.mdh
            mapped.mdh['TimedSpecies'] = timedSpecies

        namespace[self.outputName] = mapped

    def populateTimedSpecies(self):
        """Return a list of ``{'name', 't_start', 't_end'}`` dicts for the named species."""
        species_list = []

        for number in (1, 2, 3):
            name = getattr(self, 'Species_%d_Name' % number)
            if not name:
                # species slots without a name are disabled
                continue

            species_list.append({
                'name': name,
                't_start': getattr(self, 'Species_%d_Start' % number),
                't_end': getattr(self, 'Species_%d_Stop' % number)
            })

        return species_list
示例#14
0
class PointFeaturesPairwiseDist(PointFeatureBase):
    """
    Create a feature vector for each point in a point-cloud using a histogram of its distances to all other points

    One distance histogram (``numBins`` bins of width ``binWidth``) is computed
    per point, in 3D or 2D depending on ``threeD``. The stacked histograms are
    handed to ``_process_features`` and the result is stored under
    ``outputName``.
    """
    inputLocalisations = Input('localisations')
    outputName = Output('features')

    binWidth = Float(100.)  # width of the bins in nm
    numBins = Int(20)  #number of bins (starting at 0)
    threeD = Bool(True)

    normaliseRelativeDensity = Bool(
        False
    )  # divide by the sum of all radial bins. If not performed, the first principal component will likely be average density

    def execute(self, namespace):
        """Compute the per-point distance histograms and store the processed features."""
        from PYME.Analysis.points import DistHist
        points = namespace[self.inputLocalisations]

        # iterate the coordinates directly; `xrange` does not exist on
        # Python 3 and would raise a NameError here
        if self.threeD:
            x, y, z = points['x'], points['y'], points['z']
            f = np.array([
                DistHist.distanceHistogram3D(xi, yi, zi, x, y, z,
                                             self.numBins, self.binWidth)
                for xi, yi, zi in zip(x, y, z)
            ])
        else:
            x, y = points['x'], points['y']
            f = np.array([
                DistHist.distanceHistogram(xi, yi, x, y, self.numBins,
                                           self.binWidth)
                for xi, yi in zip(x, y)
            ])

        namespace[self.outputName] = self._process_features(points, f)
class InterpolatePSF(ModuleBase):
    """
        Interpolate PSF with RBF.

        Resamples the input stack onto a grid with ``target_voxelsize``
        spacing. Slow: a separate cubic RBF is fitted for every output voxel,
        using only the input voxels within ``rbf_radius`` of it.
        Only uses the first color channel.
    """

    inputName = Input('input')
    rbf_radius = Float(250.0)
    target_voxelsize = List(Float, [100., 100., 100.])

    output_images = Output('psf_interpolated')

    def execute(self, namespace):
        """Interpolate the input image and store the result as an ``ImageStack``."""
        ims = namespace[self.inputName]
        data = ims.data[:,:,:,0]

        voxelsize = [ims.mdh.voxelsize.x, ims.mdh.voxelsize.y, ims.mdh.voxelsize.z]

        # coordinates of the original voxels, centred on zero. The metadata
        # voxel size is scaled by 1E3 to match target_voxelsize (presumably
        # um -> nm, confirm against the metadata convention).
        dims_original = []
        for dim, dim_len in enumerate(data.shape):
            d = np.linspace(0, dim_len-1, dim_len) * voxelsize[dim] * 1E3
            d -= d.mean()
            dims_original.append(d)
        X, Y, Z = np.meshgrid(*dims_original, indexing='ij')

        # coordinates of the output voxels, covering the same extent
        dims_interpolated = []
        for dim, dim_len in enumerate(data.shape):
            tar_len = int(np.ceil((voxelsize[dim]*1E3 * dim_len) / self.target_voxelsize[dim]))
            d = np.arange(tar_len) * self.target_voxelsize[dim]
            d -= d.mean()
            dims_interpolated.append(d)

        X_interp, Y_interp, Z_interp = np.meshgrid(*dims_interpolated, indexing='ij')
        pts_interp = zip(X_interp.ravel(), Y_interp.ravel(), Z_interp.ravel())

        results = np.zeros(X_interp.size)
        for i, pt in enumerate(pts_interp):
            # use the configured radius; it used to be silently ignored in
            # favour of the hard-coded default
            results[i] = self.InterpolateAt(pt, X, Y, Z, data, self.rbf_radius)
            if i % 100 == 0:
                print("{} out of {} completed.".format(i, results.shape[0]))

        results = results.reshape(X_interp.shape)

        # metadata is best effort: on failure the error is printed and the
        # image is still emitted (with None or partially filled metadata)
        new_mdh = None
        try:
            new_mdh = MetaDataHandler.NestedClassMDHandler(ims.mdh)
            new_mdh["voxelsize.x"] = self.target_voxelsize[0] * 1E-3
            new_mdh["voxelsize.y"] = self.target_voxelsize[1] * 1E-3
            new_mdh["voxelsize.z"] = self.target_voxelsize[2] * 1E-3
            new_mdh['Interpolation.Method'] = 'RBF'
            new_mdh['Interpolation.RbfRadius'] = self.rbf_radius
        except Exception as e:
            print(e)
        namespace[self.output_images] = ImageStack(data=results, mdh=new_mdh)

    def InterpolateAt(self, pt, X, Y, Z, data, radius=250.):
        """Return the RBF-interpolated value at ``pt``.

        Parameters
        ----------
        pt : sequence of 3 floats
            Position to evaluate, in the same coordinates as ``X, Y, Z``.
        X, Y, Z, data : arrays
            Coordinates and values of the original voxels.
        radius : float
            Only voxels closer than this to ``pt`` are used for the fit.
        """
        X_subset, Y_subset, Z_subset, data_subset = self.GetPointsInNeighbourhood(pt, X, Y, Z, data, radius)
        rbf = interpolate.Rbf(X_subset, Y_subset, Z_subset, data_subset, function="cubic", smooth=1E3)
        return rbf(*pt)

    def GetPointsInNeighbourhood(self, center, X, Y, Z, data, radius=250.):
        """Return the coordinates and values of all voxels closer than ``radius`` to ``center``."""
        distance = np.sqrt((X-center[0])**2 + (Y-center[1])**2 + (Z-center[2])**2)
        # honour the radius argument (the cut-off was hard-coded to 250 before)
        mask = distance < radius

        return X[mask], Y[mask], Z[mask], data[mask]
示例#16
0
class Pow(Filter):
    """Raise an image to a given power (can be fractional, e.g. 0.5 for sqrt)."""
    power = Float(2)

    def applyFilter(self, data, chanNum, i, image0):
        """Return ``data`` raised element-wise to ``power``."""
        exponent = self.power
        return np.power(data, exponent)
示例#17
0
class EnsembleFitROIs(
        ModuleBase):  # Note that this should probably be moved somewhere else
    """Fit a set of ROIs with one of the fitters registered in ``region_fitters``.

    The input is loaded into a ``RegionHandler``, fitted, and the fitter's
    results are stored as a ``tabular.RecArraySource`` under ``outputName``.
    When ``hold_ensemble_parameter_constant`` is set, ``fit_profiles`` is
    called with the guess; otherwise ``ensemble_fit`` is called with it.
    """
    inputName = Input('ROIs')

    fit_type = CStr('LorentzianConvolvedSolidSphere_ensemblePSF')
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        """Run the selected fit and store the results, with metadata, in ``namespace``."""
        rois = namespace[self.inputName]

        # generate RegionHandler from tables
        handler = RegionHandler()
        handler._load_from_list(rois)

        fitter = region_fitters.fitters[self.fit_type](handler)

        if self.hold_ensemble_parameter_constant:
            run_fit = fitter.fit_profiles
        else:
            run_fit = fitter.ensemble_fit
        run_fit(self.ensemble_parameter_guess)

        res = tabular.RecArraySource(fitter.results)

        # propagate metadata, if present
        source_mdh = getattr(rois, 'mdh', None)
        res.mdh = MetaDataHandler.NestedClassMDHandler(source_mdh)

        res.mdh['EnsembleFitROIs.FitType'] = self.fit_type
        res.mdh['EnsembleFitROIs.EnsembleParameterGuess'] = (
            self.ensemble_parameter_guess)
        res.mdh['EnsembleFitROIs.HoldEnsembleParamConstant'] = (
            self.hold_ensemble_parameter_constant)

        namespace[self.outputName] = res

    @property
    def _fitter_choices(self):
        """Names offered in the fit-type combo box."""
        return list(region_fitters.ensemble_fitters.keys())  #FIXME???

    @property
    def default_view(self):
        """Full view: input, fit settings and output, with an OK button."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        input_item = Item('inputName',
                          editor=CBEditor(choices=self._namespace_keys))
        fit_type_item = Item('fit_type',
                             editor=CBEditor(choices=self._fitter_choices))

        return View(input_item,
                    Item('_'),
                    fit_type_item,
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])

    @property
    def pipeline_view(self):
        """Fit settings only, without buttons."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        fit_type_item = Item('fit_type',
                             editor=CBEditor(choices=self._fitter_choices))

        return View(
            fit_type_item,
            Item('_'),
            Item('ensemble_parameter_guess'),
            Item('_'),
            Item('hold_ensemble_parameter_constant'),
        )

    @property
    def dsview_view(self):
        """Fit settings only, with an OK button."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        fit_type_item = Item('fit_type',
                             editor=CBEditor(choices=self._fitter_choices))

        return View(fit_type_item,
                    Item('_'),
                    Item('ensemble_parameter_guess'),
                    Item('_'),
                    Item('hold_ensemble_parameter_constant'),
                    buttons=['OK'])
示例#18
0
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.

    Every displayed octree node is drawn as a box of 12 triangles (36 vertices). The per-vertex alpha is derived from
    the node's point count, normalised so that the largest value equals the layer's ``alpha`` trait.
    """
    # Additional properties (the rest are inherited from TriangleRenderLayer)
    depth = Int(3, desc='Depth at which to render Octree. Set to -1 for dynamic depth rendering.')
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, context, **kwargs)

        # re-run update whenever one of the octree-specific traits changes
        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        """Tuple of octree classes (presumably the datasource types this layer accepts - confirm in base class)."""
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Stores the resulting vertices, normals, colours and
        bounding box on the layer.

        Node selection depends on ``depth``:

        - ``depth > 0``: nodes at exactly that depth whose density (points / box volume) is at least ``density``
        - ``depth == 0``: all nodes containing at least one point
        - ``depth < 0``: nodes without children (and with depth > 0)

        In all cases only nodes whose *parent* holds at least ``min_points`` points are considered. Note that the
        ``density`` threshold is only applied in the ``depth > 0`` case.

        Parameters
        ----------
        ds :
            Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        # keep nodes whose parent contains at least min_points points
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3))*ds.box_size(self.depth)
            node_density = 1.*nodes['nPoints']/np.prod(box_sizes,axis=1)
            nodes = nodes[node_density >= self.density]
            # recompute so that box_sizes matches the density-filtered node list
            box_sizes = np.ones((nodes.shape[0], 3))*ds.box_size(self.depth)
            alpha = nodes['nPoints']/box_sizes[:,0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2 ** nodes['depth'])**3)
        else:
            # Plot leaf nodes
            nodes = nodes[(np.sum(nodes['children'],axis=1) == 0)&(nodes['depth'] > 0)]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints']*((2.0**nodes['depth']))**3

        if len(nodes) > 0:
            c = nodes['centre']  # center
            # offsets from the box centre to its 8 corners (OCT_SHIFT is defined elsewhere; presumably an (8, 3)
            # array of corner directions - confirm)
            shifts = (box_sizes[:,None]*OCT_SHIFT[None,:])*0.5
            v = (c[:,None,:] + shifts)

            #
            #     z
            #     ^
            #     |
            #    v4 ----------v6
            #    /|           /|
            #   / |          / |
            #  v5----------v7  |
            #  |  |    c    |  |
            #  | v0---------|-v2
            #  | /          | /
            #  v1-----------v3---> y
            #  /
            # x
            #
            # Now note that the counterclockwise triangles (when viewed straight-on) formed along the faces of the cube are:
            #
            # v0 v2 v1
            # v0 v1 v5
            # v0 v5 v4
            # v0 v6 v2
            # v0 v4 v6
            # v1 v2 v3
            # v1 v3 v7
            # v1 v7 v5
            # v2 v6 v7
            # v2 v7 v3
            # v4 v5 v6
            # v5 v7 v6

            # Counterclockwise triangles (when viewed straight-on) formed along
            # the faces of an octree box.
            # v[:, idx, :] has one (12, 3) slab per node, so after vstack the rows are node-major: the 12 triangles of
            # node 0, then the 12 triangles of node 1, ...
            t0 = np.vstack(v[:,[0,0,0,0,0,1,1,1,2,2,4,5],:])
            t1 = np.vstack(v[:,[2,1,5,6,4,2,3,7,6,7,5,7],:])
            t2 = np.vstack(v[:,[1,5,4,2,6,3,7,5,7,3,6,6],:])

            x, y, z = np.hstack([t0,t1,t2]).reshape(-1, 3).T  # positions

            # Now we create the normal as the cross product
            tn = np.cross((t2-t1),(t0-t1))

            # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
            xn, yn, zn = np.repeat(tn.T, 3, axis=1)  # normals

            # Color is fixed constant for octree (note: this rebinds c, which held the node centres above)
            c = np.ones(len(x))
            clim = [0, 1]

            alpha = self.alpha*alpha/alpha.max()
            # Expand to one alpha per vertex. The vertex arrays are node-major (36 consecutive vertices per node), so
            # each node's alpha must be repeated 36 times in place. Tiling the whole alpha vector 12 times instead
            # would pair alphas with the wrong boxes.
            alpha = np.repeat(alpha, 12*3)
            print('Octree scaled alpha range: %g, %g' % (alpha.min(), alpha.max()))

            cmap = getattr(cm, self.cmap)

            # Do we have coordinates? Concatenate into vertices.
            if x is not None and y is not None and z is not None:
                vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
                self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

                if not xn is None:
                    self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
                else:
                    self._normals = -0.69 * np.ones(self._vertices.shape)

                self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
            else:
                self._bbox = None

            if clim is not None and c is not None and cmap is not None:
                cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
                cs = cmap(cs_)

                if self.method in ['flat', 'tessel']:
                    alpha = cs_ * alpha

                # the 4th colour component carries the per-vertex transparency
                cs[:, 3] = alpha

                if self.method == 'tessel':
                    cs = np.power(cs, 0.333)

                self._colors = cs.ravel().reshape(len(c), 4)
            else:
                # cs = None
                if not self._vertices is None:
                    self._colors = np.ones((self._vertices.shape[0], 4), 'f')

            print('Colors: {}'.format(self._colors))

            self._alpha = alpha
            self._color_map = cmap
            self._color_limit = clim
        else:
            # nothing to draw: previously stored vertices/colours are left untouched
            print('No nodes for density {0}, depth {1}'.format(self.density, self.depth))


    @property
    def default_view(self):
        """TraitsUI view exposing datasource, method, depth, min_points, colour map and alpha."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([#Group([
                     Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')),
                     Item('method'),
                     Item('depth'),Item('min_points'),
                     #Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     #Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'), Item('alpha')])], )
示例#19
0
class ImageSource(PointSource):
    """Point source which samples points from one of the currently open images.

    The selected image is used as a density map: its first slice is
    normalised to a maximum of 1 and passed to ``locify.locify`` together
    with ``points_per_pixel``.
    """
    image = WRDictEnum(image.openImages)
    points_per_pixel = Float(0.1)

    # help texts shown in the settings dialog, keyed by trait name
    helpInfo = {
        'points_per_pixel':
        '''
Select average number of points (dye molecules or docking sites) per pixel in image regions where the density values are 1.
The number is a floating point fraction, e.g. 0.1, and shouldn't exceed 1. It is used for Monte-Carlo rejection of positions
and larger values (>~0.2) will result in images which have visible pixel-grid structure because the Monte-Carlo sampling
is no longer a good approximation to random sampling over the grid. If this is a problem for your application / you can't
get high enough density without a high acceptance fraction, use an up-sampled source image with a smaller pixel size.
''',
        'image':
        '''
Select an image from the list of open images.
Note that you need to open or generate the source image you want to use so that this
list is not empty. The image will be normalised for the purpose of the simulation,
with its maximum set to 1. It describes the density of markers in the simulated sample,
where values of 1 have a density of markers as given by the `points per pixel` parameter, i.e.
in the Monte-Carlo sampling the acceptance probability = image*points_per_pixel. Smaller
density values therefore give rise to proportionally fewer markers per pixel.
''',
    }

    def helpStr(self, name):
        """Return the help text for trait ``name`` collapsed onto a single line."""
        raw_text = self.helpInfo[name]
        return raw_text.strip().replace('\n', ' ').replace('\r', '')

    def default_traits_view(self):
        """Build the settings dialog, attaching help text and tooltips to both traits."""
        from traitsui.api import View, Item

        density_item = Item('points_per_pixel',
                            help=self.helpStr('points_per_pixel'),
                            tooltip='mean number of marker points per pixel')
        image_item = Item(
            'image',
            help=self.helpStr('image'),
            tooltip=
            'select the marker density image from the list of open images')

        return View(density_item, image_item, buttons=['OK', 'Help'])

    def getPoints(self):
        """Generate points from the selected image.

        Raises
        ------
        UserError
            If no open image matches the current ``image`` setting.
        """
        from PYME.simulation import locify

        try:
            im = image.openImages[self.image]
        except KeyError:
            # No image of that name. Left uncaught so that it surfaces in the
            # error dialog; UserError marks it as a user mistake, not a bug.
            message = (
                'No open image found with name: "%s", please set "image" property of ImageSource to a valid image name\nThis must be an image which is already open.\n\n'
                % self.image)
            raise UserError(message)

        density = im.data[:, :, 0, 0].astype('f')

        # normalise the image so that its maximum is 1
        density = density / density.max()

        return locify.locify(density,
                             pixelSize=im.pixelSize,
                             pointsPerPixel=self.points_per_pixel)

    def get_bounds(self):
        """Return the ``imgBounds`` of the selected image."""
        return image.openImages[self.image].imgBounds

    def refresh_choices(self):
        """Ask the editor of the ``image`` trait (if any) to refresh its list of values."""
        ed = self.trait('image').editor
        if not ed:
            return

        try:
            ed._values_changed()
        except TypeError:
            # TODO - why can _values_changed be None??
            # is there a better way to handle/avoid this?
            pass

    def genMetaData(self, mdh):
        """Write the source type, density and image name into ``mdh``."""
        prefix = 'GeneratedPoints.Source.'

        mdh[prefix + 'Type'] = 'Image'
        mdh[prefix + 'PointsPerPixel'] = self.points_per_pixel
        mdh[prefix + 'Image'] = self.image
示例#20
0
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.

    The layer reads a mesh datasource (selected by ``dsname``) from the pipeline, expands it into per-triangle-corner
    vertex, normal and colour arrays, and exposes those arrays through the ``get_*`` accessors.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Object providing ``dataSources``, ``get_layer_data()`` and an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the datasource to display.
        context :
            Passed on to ``EngineLayer`` (and, through ``self._context``, to the engine constructor).
        **kwargs :
            Additional trait values, forwarded to ``EngineLayer.__init__`` and ``self.set``.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # assigning an unchanged value would not fire the trait handler, so make sure we still call
            # _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to, as resolved by the pipeline's ``get_layer_data(dsname)``.
        """
        return self._pipeline.get_layer_data(self.dsname)

    @property
    def _ds_class(self):
        """Class used to decide which pipeline datasources are offered as choices for this layer."""
        # imported lazily, at the point of use
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """Instantiate the engine registered under ``self.method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the colouring variable, or ``np.array([0, 1])`` if it cannot be looked up.

        The fallback covers both an unknown key (``KeyError``) and a datasource which does not support indexing,
        e.g. ``None`` (``TypeError``).
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Handler for a change of ``vertexColour``: reset ``clim`` to the data range, then update."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """
        Refresh the GUI choice lists and, if both an engine and a datasource are available, rebuild the render
        arrays and notify ``on_update`` listeners.
        """
        ds_class = self._ds_class
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, ds_class)]

        # resolve the datasource once rather than going back to the pipeline for every use
        ds = self.datasource

        if ds is not None:
            dks = ['constant', ]
            if hasattr(ds, 'keys'):
                dks = dks + sorted(ds.keys())
            self._datasource_keys = dks

        if not (self.engine is None or ds is None):
            print('lw update')
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[xmin, ymin, zmin, xmax, ymax, zmax]`` of the last update, or None before the first."""
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Build per-triangle-corner vertex, normal and colour arrays from a mesh datasource.

        Parameters
        ----------
        ds :
            Mesh object providing ``vertices``, ``faces``, ``vertex_normals`` and ``face_normals``, and supporting
            ``ds[name]`` lookup of the colouring variable (expected to be a
            PYME.experimental.triangular_mesh.TriangularMesh object, per the ``dsname`` trait description).

        Returns
        -------
        None
        """
        # one row per triangle corner
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            # repeat each face normal for the three corners of that face
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # x, y, z, the normals and c are always arrays at this point, so no None-handling is needed
        n_vertices = len(x)
        self._vertices = np.vstack((x, y, z)).T.ravel().reshape(n_vertices, 3)
        self._normals = np.vstack((xn, yn, zn)).T.ravel().reshape(n_vertices, 3)
        self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])

        # scale the colouring variable by clim and map it to RGBA
        cs_ = ((c - clim[0]) / (clim[1] - clim[0]))
        cs = cmap(cs_)

        if self.method in ['flat', 'tessel']:
            # for these methods the opacity follows the scaled colour value
            alpha = cs_ * alpha

        cs[:, 3] = alpha

        if self.method == 'tessel':
            cs = np.power(cs, 0.333)

        self._colors = cs.ravel().reshape(len(c), 4)

        # NB: for 'flat' / 'tessel' this is the per-vertex alpha array rather than a scalar
        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        """TraitsUI view exposing the datasource, method, normal mode and colouring options."""
        from traitsui.api import View, Item, Group, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
示例#21
0
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, using one of the engines registered in ``ENGINES``.

    The layer reads coordinate columns (and optionally normals and a colouring variable) from a pipeline
    datasource and exposes the resulting vertex / normal / colour arrays through the ``get_*`` accessors.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='points',
                 dsname='',
                 context=None,
                 **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Object providing ``layer_data_source_names``, ``get_layer_data()`` and an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the datasource to display.
        context :
            Passed on to ``EngineLayer`` (and, through ``self._context``, to the engine constructor).
        **kwargs :
            Additional trait values, forwarded to ``EngineLayer.__init__`` and ``self.set``.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        self.method = method

        # called explicitly so that an engine is created even if `method` was already the current value
        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Instantiate the engine registered under ``self.method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the colouring variable, or ``np.array([0, 1])`` if it cannot be looked up.

        ``TypeError`` is caught as well as ``KeyError`` so that a datasource which does not support indexing
        (e.g. ``None``) produces the fallback instead of an exception.
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Handler for a change of ``vertexColour``: reset ``clim`` to the data range.

        No explicit ``update()`` is needed here, as ``clim`` is itself one of the traits that triggers it.
        """
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """
        Refresh the GUI choice lists and, if both an engine and a datasource are available, rebuild the render
        arrays and notify ``on_update`` listeners.
        """
        print('lw update')
        self._datasource_choices = self._pipeline.layer_data_source_names

        # resolve the datasource once rather than going back to the pipeline for every use
        ds = self.datasource
        if ds is not None:
            self._datasource_keys = sorted(ds.keys())

        if not (self.engine is None or ds is None):
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[xmin, ymin, zmin, xmax, ymax, zmax]`` of the last update, or None."""
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Read coordinates, colouring variable and (if present) normals from ``ds`` and pass them to ``update_data``.

        Parameters
        ----------
        ds :
            Mapping-like datasource supporting ``ds[key]`` and ``ds.keys()``.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        # fall back to z = 0 when there is no z key or the datasource has no such column
        if self.z_key is not None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        # normals are optional; presence of the x-normal column decides whether all three are read
        normals = {}
        if self.xn_key in ds.keys():
            normals = dict(xn=ds[self.xn_key],
                           yn=ds[self.yn_key],
                           zn=ds[self.zn_key])

        self.update_data(x,
                         y,
                         z,
                         c,
                         cmap=getattr(cm, self.cmap),
                         clim=self.clim,
                         alpha=self.alpha,
                         **normals)

    def update_data(self,
                    x=None,
                    y=None,
                    z=None,
                    colors=None,
                    cmap=None,
                    clim=None,
                    alpha=1.0,
                    xn=None,
                    yn=None,
                    zn=None):
        """
        Rebuild the vertex, normal and colour arrays from raw coordinate / colour data.

        Parameters
        ----------
        x, y, z : array or None
            Point coordinates. If any is None, no vertices are produced and the bounding box is cleared.
        colors : array or None
            Per-point values which are scaled by ``clim`` and mapped through ``cmap``.
        cmap : callable or None
            Maps scaled values to an ``(N, 4)`` RGBA array.
        clim : sequence or None
            ``[lower, upper]`` scaling limits for ``colors``.
        alpha : float
            Opacity written into the alpha column of the colours.
        xn, yn, zn : array or None
            Optional per-point normals; placeholder normals of -0.69 are used when ``xn`` is None.
        """
        # reset previous state; set_values() below only overwrites entries which are not None
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0

        if x is not None and y is not None and z is not None:
            n_points = len(x.ravel())
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(n_points, 3)

            if xn is not None:
                normals = np.vstack(
                    (xn.ravel(), yn.ravel(),
                     zn.ravel())).T.ravel().reshape(n_points, 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array(
                [x.min(), y.min(),
                 z.min(), x.max(),
                 y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # all three of clim, colors and cmap are required for colour mapping (cmap is called below)
        if clim is not None and colors is not None and cmap is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        elif vertices is not None:
            # no colour information - use all-ones RGBA for every vertex
            cs = np.ones((vertices.shape[0], 4), 'f')
        else:
            cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self,
                   vertices=None,
                   normals=None,
                   colors=None,
                   color_map=None,
                   color_limit=None,
                   alpha=None):
        """Store each argument which is not None on the layer; None arguments leave the current value alone."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        """TraitsUI view exposing the datasource, method, colouring, LUT, alpha and point size options."""
        from traitsui.api import View, Item, Group, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     visible_when=
                     "method in ['pointsprites', 'transparent_points']",
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)),
                Item('point_size',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)))
        ])
        #buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
示例#22
0
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    A single 2D plane (selected by ``slice`` and ``channel``) of the image datasource is cached as a float32
    array together with its bounds; colour map, colour limits and alpha are stored alongside it.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Either an object providing ``dataSources`` / ``get_layer_data()`` (and optionally an ``onRebuild``
            signal), or a plain dictionary mapping names to images.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the image to display.
        display_opts :
            Optional dh5view display_options instance, used by ``sync_to_display_opts``.
        context :
            Passed on to ``EngineLayer`` (and, through ``self._context``, to the engine constructor).
        **kwargs :
            Additional trait values, forwarded to ``EngineLayer.__init__`` and ``self.set``.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        self._do = display_opts #a dh5view display_options instance - if provided, this over-rides the the clim, cmap properties

        # cache of the currently displayed plane, and the (dsname, slice, channel) key it was extracted for
        self._im_key = None
        self._im = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # assigning an unchanged value would not fire the trait handler, so make sure we still call
            # _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (self._pipeline is not None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the image we are connected to.

        Uses the pipeline's ``get_layer_data(dsname)`` when available, falling back to ``pipeline[dsname]`` when
        that raises ``AttributeError`` (i.e. the pipeline is a plain dictionary).
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            #fallback if pipeline is a dictionary
            return self._pipeline[self.dsname]

    @property
    def _ds_class(self):
        """Class used to decide which pipeline entries are offered as choices for this layer."""
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        """Instantiate the engine registered under ``self.method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """
        Refresh the list of selectable images and, if both an engine and a datasource are available, refresh the
        cached image state and notify ``on_update`` listeners.
        """
        ds_class = self._ds_class
        try:
            self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, ds_class)]
        except AttributeError:
            # pipeline is a plain dictionary
            self._datasource_choices = [k for k, v in self._pipeline.items() if
                                        isinstance(v, ds_class)]

        # resolve the datasource once rather than going back to the pipeline for every use
        ds = self.datasource

        if not (self.engine is None or ds is None):
            print('lw update')
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[x0, y0, 0, x1, y1, 0]`` of the cached image, or None before the first update."""
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """
        Copy colour limits, colour map and visibility for our channel from a display options object.

        Parameters
        ----------
        do :
            Object with per-channel ``Offs``, ``Gains``, ``cmaps`` and ``show`` sequences. If None, the
            ``display_opts`` given at construction is used; if that is None as well, nothing happens.
        """
        if do is None:
            if self._do is None:
                return
            do = self._do

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Cache the selected image plane and record the current display settings.

        The plane and its bounds are only re-extracted when the ``(dsname, slice, channel)`` combination differs
        from the one cached previously.

        Parameters
        ----------
        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        # NB: display options (self._do) are only applied through sync_to_display_opts(), not here
        clim = self.clim
        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        im_key = (self.dsname, self.slice, self.channel)

        if not self._im_key == im_key:
            self._im_key = im_key
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])
            self._bounds = [x0, y0, x1, y1]

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    def _get_cdata(self):
        """
        Return a sub-sampled (every 20th pixel) flat copy of the cached image for the histogram editor.

        Falls back to ``np.array([0, 1])`` when no image has been cached yet, so that the view can be built
        before the first successful update.
        """
        if self._im is None:
            return np.array([0, 1])

        return self._im.ravel()[::20]

    @property
    def default_view(self):
        """TraitsUI view exposing the datasource, colour limits, LUT and alpha options."""
        from traitsui.api import View, Item, Group, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     #Item('method'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )
        # buttons=['OK', 'Cancel'])

    def default_traits_view(self):
        return self.default_view
示例#23
0
class EnsembleFitProfiles(ModuleBase):
    """
    Fit a set of line profiles with one of the fitters registered in ``profile_fitters.ensemble_fitters``.

    Depending on ``hold_ensemble_parameter_constant`` the fitter's ``fit_profiles`` or ``ensemble_fit`` method is
    called with ``ensemble_parameter_guess``; the fitter's ``results`` are written to the output as a tabular
    source, with the fit settings recorded in its metadata.
    """
    inputName = Input('line_profiles')

    fit_type = CStr(list(profile_fitters.ensemble_fitters.keys())[0])
    ensemble_parameter_guess = Float(50.)
    hold_ensemble_parameter_constant = Bool(False)

    outputName = Output('fit_results')

    def execute(self, namespace):
        """Run the selected fit on ``namespace[inputName]`` and store the result in ``namespace[outputName]``."""
        profiles = namespace[self.inputName]

        # generate LineProfileHandler from tables
        handler = LineProfileHandler()
        handler._load_profiles_from_list(profiles)

        self.fitter = profile_fitters.ensemble_fitters[self.fit_type](handler)

        if self.hold_ensemble_parameter_constant:
            run_fit = self.fitter.fit_profiles
        else:
            run_fit = self.fitter.ensemble_fit
        run_fit(self.ensemble_parameter_guess)

        fit_table = tabular.RecArraySource(self.fitter.results)

        # propagate metadata, if present
        fit_table.mdh = MetaDataHandler.NestedClassMDHandler(
            getattr(profiles, 'mdh', None))

        settings = (
            ('EnsembleFitProfiles.FitType', self.fit_type),
            ('EnsembleFitProfiles.EnsembleParameterGuess',
             self.ensemble_parameter_guess),
            ('EnsembleFitProfiles.HoldEnsembleParamConstant',
             self.hold_ensemble_parameter_constant),
        )
        for key, value in settings:
            fit_table.mdh[key] = value

        namespace[self.outputName] = fit_table

    @property
    def _fitter_choices(self):
        """Names of the available ensemble fitters."""
        return list(profile_fitters.ensemble_fitters.keys())

    def _fit_settings_view_items(self):
        """Build the (freshly created) view items for the fit settings, which all three views share."""
        from traitsui.api import Item
        from PYME.ui.custom_traits_editors import CBEditor

        return [
            Item('fit_type', editor=CBEditor(choices=self._fitter_choices)),
            Item('_'),
            Item('ensemble_parameter_guess'),
            Item('_'),
            Item('hold_ensemble_parameter_constant'),
        ]

    @property
    def default_view(self):
        """Full view: input selection, fit settings and output name, with an OK button."""
        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        leading = [
            Item('inputName', editor=CBEditor(choices=self._namespace_keys)),
            Item('_'),
        ]
        trailing = [Item('_'), Item('outputName')]
        items = leading + self._fit_settings_view_items() + trailing

        return View(*items, buttons=['OK'])

    @property
    def pipeline_view(self):
        """Fit settings only, without buttons."""
        from traitsui.api import View

        return View(*self._fit_settings_view_items())

    @property
    def dsview_view(self):
        """Fit settings only, with an OK button."""
        from traitsui.api import View

        return View(*self._fit_settings_view_items(), buttons=['OK'])
class PreprocessingFilter(CacheCleanupModule):
    """
    Optional. Combines a few image processing operations. Only 3D.
    
    1. Applies median filter for denoising.
    
    2. Replaces out of range values to defined values.
    
    3. Applies 2D Tukey filter to dampen potential edge artifacts.
        
    Inputs
    ------
    input_name : ImageStack
    
    Outputs
    -------
    output_name : ImageStack
    
    Parameters
    ----------
    median_filter_size : int
        Median filter size (``scipy.ndimage.median_filter``). Values <= 0 disable the filter.
    threshold_lower : Float
        Pixels at this value or lower are replaced.
    clip_to_lower : Float
        Pixels below the lower threshold are replaced by this value. It is also subtracted from the
        whole image afterwards.
    threshold_upper : Float
        Pixels at this value or higher are replaced.
    clip_to_upper : Float
        Pixels above the upper threshold are replaced by this value.
    tukey_size : float
        Shape parameter for Tukey filter (``scipy.signal.windows.tukey``). Values <= 0 disable the window.
    cache_clip : File
        Use file as disk cache if provided.
    """
    input_name = Input('input')
    threshold_lower = Float(0)
    clip_to_lower = Float(0)
    threshold_upper = Float(65535)
    clip_to_upper = Float(0)
    median_filter_size = Int(3)
    tukey_size = Float(0.25)
    cache_clip = File("clip_cache.bin")
    output_name = Output('clipped_images')

    def _execute(self, namespace):
        """Filter ``namespace[input_name]`` chunk by chunk and store the result in ``namespace[output_name]``."""
        self._start_time = time.time()
        ims = namespace[self.input_name]

        dtype = ims.data[:, :, 0].dtype

        # plain Python ints; the `np.long` alias used previously is not available in all NumPy versions
        size_x, size_y, n_frames = (int(s) for s in ims.data.shape[:3])
        out_shape = (size_x, size_y, n_frames)

        # Somewhat arbitrary way to decide on chunk size
        chunk_size = 100000000 / size_x / size_y / dtype.itemsize
        chunk_size = int(max(1, chunk_size))

        # separable 2D window, with a trailing axis so that it broadcasts over the frames of a chunk
        # (`windows.tukey` is the supported location; the top-level `signal.tukey` alias is gone in newer SciPy)
        tukey_mask_x = signal.windows.tukey(size_x, self.tukey_size)
        tukey_mask_y = signal.windows.tukey(size_y, self.tukey_size)
        self._tukey_mask_2d = (tukey_mask_x[:, None] * tukey_mask_y[None, :])[:, :, None]

        if self.cache_clip == "":
            raw_data = np.empty(out_shape, dtype=dtype)
        else:
            # disk-backed output buffer
            raw_data = np.memmap(self.cache_clip,
                                 dtype=dtype,
                                 mode='w+',
                                 shape=out_shape)

        # report progress (and flush the disk cache) roughly every 20% of the frames
        progress_step = 0.2 * n_frames
        progress = progress_step
        for f in range(0, n_frames, chunk_size):
            raw_data[:, :, f:f + chunk_size] = self.applyFilter(
                ims.data[:, :, f:f + chunk_size])

            if (f + chunk_size >= progress):
                if isinstance(raw_data, np.memmap):
                    raw_data.flush()
                progress += progress_step
                print("{:.2f} s. Completed clipping {} of {} total images.".
                      format(time.time() - self._start_time,
                             min(f + chunk_size, n_frames),
                             n_frames))

        clipped_images = ImageStack(raw_data, mdh=ims.mdh)
        self.completeMetadata(clipped_images)

        namespace[self.output_name] = clipped_images

    def applyFilter(self, data):
        """
        Perform the actual filtering on one chunk of frames.

        Applies, in order: median filter (if enabled), replacement of out-of-range values, subtraction of
        ``clip_to_lower`` and the Tukey window (if enabled). Requires ``self._tukey_mask_2d`` to have been set
        (done in ``_execute``) when ``tukey_size > 0``.

        Parameters
        ----------
        data : array
            Chunk of image data; it is not modified.

        Returns
        -------
        array
            The filtered chunk.
        """
        if self.median_filter_size > 0:
            data = ndimage.median_filter(data,
                                         self.median_filter_size,
                                         mode='nearest')
        else:
            # the clipping below works in place - operate on a copy so the caller's data stays untouched
            data = np.array(data)

        data[data >= self.threshold_upper] = self.clip_to_upper
        data[data <= self.threshold_lower] = self.clip_to_lower
        data -= self.clip_to_lower
        if self.tukey_size > 0:
            data = data * self._tukey_mask_2d
        return data

    def completeMetadata(self, im):
        """Record the clipping and Tukey settings in the metadata of ``im``."""
        im.mdh['Processing.Clipping.LowerBounds'] = self.threshold_lower
        im.mdh['Processing.Clipping.LowerSetValue'] = self.clip_to_lower
        im.mdh['Processing.Clipping.UpperBounds'] = self.threshold_upper
        im.mdh['Processing.Clipping.UpperSetValue'] = self.clip_to_upper
        im.mdh['Processing.Tukey.Size'] = self.tukey_size
class RGBImageUpload(ImageUpload):
    """
    Render an image as an RGB png and upload it to an OMERO server,
    optionally attaching localization files.

    Parameters
    ----------
    input_image : str
        name of the image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular types (keys) to attachment filenames (values). Tabulars
        will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine the name of the image on the OMERO server
    omero_dataset : str
        name of the OMERO dataset to add the image to. If the dataset does
        not already exist it will be created.
    omero_project : str
        name of the OMERO project to link the dataset to. If the project does
        not already exist it will be created.
    zoom : int
        how large to zoom the image
    scaling : str
        how to scale the intensity - one of 'min-max' or 'percentile'
    scaling_factor : float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta, and yellow rather than RGB. True by default.

    Notes
    -----
    The OMERO server address and the user login information must be stored
    in the user PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        """
        Convert ``image`` to an RGB array and write it to ``path`` with PIL.

        The cyan/magenta/yellow conversion is used when
        ``colorblind_friendly`` is set and the 4th data axis of the image does
        not have length 1; the plain RGB conversion is used otherwise.
        """
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        use_cmy = self.colorblind_friendly and (image.data.shape[3] != 1)
        render = image_to_cmy if use_cmy else image_to_rgb
        pixels = render(image,
                        zoom=self.zoom,
                        scaling=self.scaling,
                        scaling_factor=self.scaling_factor)

        Image.fromarray(pixels, mode='RGB').save(path)
class RCCDriftCorrection(RCCDriftCorrectionBase):
    """
    For localization data.
    
    Performs drift correction using cross-correlation, including redundant RCC from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    
    Runtime will vary hugely depending on size of dataset and settings.
    
    ``cache_fft`` is necessary for large datasets.
        
    Inputs
    ------
    input_for_correction : Tabular
        Dataset used to calculate drift.
    input_for_mapping : Tabular
        *Deprecated.*   Dataset to correct.
    
    Outputs
    -------
    output_drift : Tuple of arrays
        Drift results.
    output_drift_plot : Plot
        *Deprecated.*   Plot of drift results.
    output_cross_cor : ImageStack
        Cross correlation images if ``debug_cor_file`` is not blank.
    outputName : Tabular
        *Deprecated.*   Drift-corrected dataset.
    
    Parameters
    ----------
    step : Int
        Setting for image construction. Step size between images
    window : Int
        Setting for image construction. Number of frames used per image. Should be equal or larger than step size.
    binsize : Float
        Setting for image construction. Pixel size.
    flatten_z : Bool
        Setting for image construction. Ignore z information if enabled.
    tukey_size : Float
        Setting for image construction. Shape parameter for Tukey filter (``scipy.signal.tukey``).
    cache_fft : File
        Use file as disk cache if provided.
    method : String
        Redundant, mean, or direct cross-correlation.
    shift_max : Float
        Rejection threshold for RCC.
    corr_window : Int
        Size of correlation window. Frames are only compared if within this frame range. N/A for DCC.
    multiprocessing : Bool
        Enables multiprocessing.
    debug_cor_file : File
        Enables debugging. Use file as disk cache if provided.
    """
    
    input_for_correction = Input('Localizations')
    input_for_mapping = Input('Localizations')
    # settings for image construction
    step = Int(2500)
    window = Int(2500)
    binsize = Float(30)
    flatten_z = Bool()
    tukey_size = Float(0.25)

    outputName = Output('corrected_localizations')
    
    def calc_corr_drift_from_locs(self, x, y, z, t):
        """
        Bin the localizations into one image per time window, Fourier transform
        the images and cross-correlate them.

        Parameters
        ----------
        x, y, z : ndarray
            Coordinates of the localizations. They are histogrammed with bins
            of width ``binsize``.
        t : ndarray
            Time (frame) of each localization. Does not need to be sorted.

        Returns
        -------
        time_values_mid : ndarray
            Centre time of each window.
        shifts : ndarray
            Shifts returned by ``calc_corr_drift_from_ft_images`` multiplied by
            ``binsize``, with the columns put back into x, y, z order.
        coefs : ndarray
            Coefficient matrix returned by ``calc_corr_drift_from_ft_images``.
        """

        # bin edges for histogram
        bx = np.arange(x.min(), x.max() + self.binsize + 1, self.binsize)
        by = np.arange(y.min(), y.max() + self.binsize + 1, self.binsize)
        bz = np.arange(z.min(), z.max() + self.binsize + 1, self.binsize)
        
        # pad bin length to odd number so image size is even
        if bx.shape[0] % 2 == 0:
            bx = np.concatenate([bx, [bx[-1] + bx[1] - bx[0]]])
        if by.shape[0] % 2 == 0:
            by = np.concatenate([by, [by[-1] + by[1] - by[0]]])
        if bz.shape[0] > 2 and bz.shape[0] % 2 == 0:
            bz = np.concatenate([bz, [bz[-1] + bz[1] - bz[0]]])
        assert (bx.shape[0] % 2 == 1) and (by.shape[0] % 2 == 1), "Ops. Image not correctly padded to even size."

        # start time of all windows, allow partial window near end of pipeline
        time_values = np.arange(t.min(), t.max() + 1, self.step)
        # 2d array, start and end time of windows
        time_values = np.stack([time_values, np.clip(time_values + self.window, None, t.max())], axis=1)
        n_steps = time_values.shape[0]
        # center time of each window for returning. last window may have different spacing
        time_values_mid = time_values.mean(axis=1)

        if (np.any(np.diff(t) < 0)): # in case pipeline is not sorted for whatever reason
            t_sort_arg = np.argsort(t)
            t = t[t_sort_arg]
            x = x[t_sort_arg]
            y = y[t_sort_arg]
            z = z[t_sort_arg]
            
        # first and one-past-last localization index of each window
        time_indexes = np.zeros_like(time_values, dtype=int)
        time_indexes[:, 0] = np.searchsorted(t, time_values[:, 0], side='left')
        time_indexes[:, 1] = np.searchsorted(t, time_values[:, 1]-1, side='right')

        # Fourier transformed (and binned) set of images to correlate against
        # one another
        # Crude way of swapping longest axis to the last for optimizing rfft performance.
        # Code changed for this is limited to this method.
        xyz = np.asarray([x, y, z])
        # The axes generally have different numbers of bin edges. numpy no longer
        # builds such ragged arrays implicitly, so fill a 1d object array instead.
        bxyz = np.empty(3, dtype=object)
        for axis, edges in enumerate((bx, by, bz)):
            bxyz[axis] = edges
        dims_order = np.arange(len(xyz))
        dims_length = np.asarray([len(b) for b in bxyz])
        dims_largest_index = np.argmax(dims_length)
        dims_order[-1], dims_order[dims_largest_index] = dims_order[dims_largest_index], dims_order[-1]
        xyz = xyz[dims_order]
        bxyz = bxyz[dims_order]
        dims_length = dims_length[dims_order]
        
        # use memmap for caching if cache_fft is defined
        # (np.complex was only an alias of the builtin complex and no longer exists in numpy)
        ft_shape = (n_steps, dims_length[0]-1, dims_length[1]-1, (dims_length[2]-1)//2 + 1, )
        if self.cache_fft == "":
            ft_images = np.zeros(ft_shape, dtype=np.complex128)
        else:
            ft_images = np.memmap(self.cache_fft, dtype=np.complex128, mode='w+', shape=ft_shape)
        
        print(ft_images.shape)
        print("{:,} bytes".format(ft_images.nbytes))
        
        print("{:.2f} s. About to start heavy lifting.".format(time.time() - self._start_time))
        
        # report progress about 5 times; never 0, which would break the modulo for n_steps < 5
        report_every = max(n_steps // 5, 1)

        # fill ft_images
        # if multiprocessing, can either use or not caching
        # if not multiprocessing, don't pass filenames for caching, just the memmap array is fine
        if self.multiprocessing:
            dt = ft_images.dtype
            sh = ft_images.shape
            args = [(i, xyz[:,slice(*ti)].T, bxyz, (self.cache_fft, dt, sh, i), self.tukey_size) for i, ti in enumerate(time_indexes)]

            for i, (j, res) in enumerate(self._pool.imap_unordered(calc_fft_from_locs_helper, args)):
                # without a cache file the result has to be stored here
                if self.cache_fft == "":
                    ft_images[j] = res
                 
                if ((i+1) % report_every == 0):
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i+1, n_steps))

        else:
            # For each window we wish to correlate...
            for i, ti in enumerate(time_indexes):
    
                # .. we generate an image and store ft of image
                t_slice = slice(*ti)
                ft_images[i] = calc_fft_from_locs(xyz[:,t_slice].T, bxyz, filter_size=self.tukey_size)
                
                if ((i+1) % report_every == 0):
                    print("{:.2f} s. Completed calculating {} of {} total ft images.".format(time.time() - self._start_time, i+1, n_steps))
        
        print("{:.2f} s. Finished generating ft array.".format(time.time() - self._start_time))
        print("{:,} bytes".format(ft_images.nbytes))
        
        shifts, coefs = self.calc_corr_drift_from_ft_images(ft_images)
        
        # clean up of ft_images, potentially really large array
        if isinstance(ft_images, np.memmap):
            ft_images.flush()
        del ft_images
        
        # dims_order is a single swap, so indexing with it again restores x, y, z order
        return time_values_mid, self.binsize * shifts[:, dims_order], coefs

    def _execute(self, namespace):
        """
        Calculate the drift from ``input_for_correction``, apply it to
        ``input_for_mapping`` and store the results in ``namespace``.

        Writes the corrected dataset (``outputName``), the drift as a tuple
        ``(t_shift, shifts)`` (``output_drift``), a plot of the drift
        (``output_drift_plot``) and the cross correlation images
        (``output_cross_cor``).
        """
        self.trait_setq(**{"_start_time": time.time()})
        print("Starting drift correction module.")
        
        if self.multiprocessing:
            proccess_count = np.clip(multiprocessing.cpu_count()-1, 1, None)
            self.trait_setq(**{"_pool": multiprocessing.Pool(processes=proccess_count)})
        
        finished = False
        try:
            locs = namespace[self.input_for_correction]

            drift_res = self.calc_corr_drift_from_locs(locs['x'], locs['y'], locs['z'] * (0 if self.flatten_z else 1), locs['t'])
            t_shift, shifts = self.rcc(self.shift_max,  *drift_res)
            finished = True
        finally:
            # always release the worker processes, also when the calculation failed
            if self.multiprocessing:
                if finished:
                    self._pool.close()
                else:
                    self._pool.terminate()
                self._pool.join()
            
        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        out = tabular.mappingFilter(namespace[self.input_for_mapping])
        t_out = out['t']
        # cubic interpolate with no smoothing
        dx = interpolate.CubicSpline(t_shift, shifts[:, 0])(t_out)
        dy = interpolate.CubicSpline(t_shift, shifts[:, 1])(t_out)
        dz = interpolate.CubicSpline(t_shift, shifts[:, 2])(t_out)

        had_drift_columns = 'dx' in out.keys()

        out.addColumn('dx', dx)
        out.addColumn('dy', dy)
        out.addColumn('dz', dz)

        if had_drift_columns:
            # getting around oddity with mappingFilter
            # addColumn adds a new column but also keeps the old column
            # __getitem__ returns the new column
            # but mappings uses the old column
            # Wrap with another level of mappingFilter so the new column becomes the 'old column'
            out = tabular.mappingFilter(out)

        out.setMapping('x', 'x + dx')
        out.setMapping('y', 'y + dy')
        out.setMapping('z', 'z + dz')

        # propagate metadata, if present
        try:
            out.mdh = locs.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = out
        namespace[self.output_drift] = t_shift, shifts
        
        # non essential, only for plotting out drift data
        namespace[self.output_drift_plot] = Plot(partial(generate_drift_plot, t_shift, shifts))
        
        namespace[self.output_cross_cor] = self._cc_image
class RCCDriftCorrectionBase(CacheCleanupModule):
    """    
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only outputs the drift as a tuple of time points and drift amount.
    Currently not registered by itself since it is not very useful.
    """

    cache_fft = File("rcc_cache.bin")
    method = Enum(['RCC', 'MCC', 'DCC'])
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    shift_max = Float(5)  # nm
    corr_window = Int(5)
    multiprocessing = Bool()
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """
        Cross-correlate pairs of Fourier transformed images.

        For ``method == "DCC"`` only the first image is compared with all later
        ones. Otherwise image ``i`` is compared with every later image ``j``,
        limited to ``j - i <= corr_window`` when ``corr_window > 0``.

        Also sets ``_cc_image`` to an ImageStack of the cross correlations if
        ``debug_cor_file`` is not blank, and to None otherwise.

        Parameters
        ----------
        ft_images : ndarray or np.memmap
            Fourier transformed images, indexed by image along the first axis.

        Returns
        -------
        shifts : ndarray
            One row of 3 shift values per compared image pair. Rows containing
            NaN are removed.
        coefs : ndarray
            Matrix with one row per compared pair ``(i, j)`` which is 1 in the
            columns ``i`` to ``j - 1`` and 0 elsewhere.
        """
        n_steps = ft_images.shape[0]

        # Matrix equation coefficient matrix
        # Shape can be predetermined based on method
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            # an image has at most n_steps - 1 later images to be compared with,
            # a larger window must not change the number of pairs
            window = min(self.corr_window, n_steps - 1)
            coefs_size = n_steps * window - window * (window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        counter = 0

        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        if not self.debug_cor_file == "":
            cc_file_shape = [
                shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                (ft_images.shape[3] - 1) * 2
            ]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]),
                          key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float was only an alias of the builtin float and no longer exists in numpy
            cc_file_args = (self.debug_cor_file, np.float64,
                            tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0],
                                dtype=cc_file_args[1],
                                mode="w+",
                                shape=cc_file_args[2])
            # has to be a list, it is indexed by counter in the serial branch below
            cc_args = list(
                zip(range(shifts.shape[0]),
                    (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # For each ft image, calculate correlation
        for i in np.arange(0, n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in np.arange(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (
                        j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                coefs[counter, i:j] = 1

                # if multiprocessing, use cache when defined
                if self.multiprocessing:
                    # if reading ft_images from cache, replace ft_1 and ft_2 with their indices
                    if not self.cache_fft == "":
                        ft_1 = i
                        ft_2 = j

                    ft_1_cache.append(ft_1)
                    ft_2_cache.append(ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift,
                                                    None, cc_args[counter])

                    if ((counter + 1) % max(coefs_size // 5, 1) == 0):
                        print(
                            "{:.2f} s. Completed calculating {} of {} total shifts."
                            .format(time.time() - self._start_time,
                                    counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            # the pairs were only collected above, calculate them in the pool now
            args = zip(
                range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                autocor_shift_cache,
                len(ft_1_cache) *
                ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                cc_args)
            for i, (j, res) in enumerate(
                    self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if ((i + 1) % max(coefs_size // 5, 1) == 0):
                    print(
                        "{:.2f} s. Completed calculating {} of {} total shifts."
                        .format(time.time() - self._start_time, i + 1,
                                coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(
            time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if not self.debug_cor_file == "":
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        # every preallocated row must have been used by a pair
        assert (np.all(np.any(
            coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop the pairs for which no valid shift was found
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".
                  format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (
            np.linalg.matrix_rank(coefs) == n_steps -
            1), "Something went wrong with coefficient matrix. Not full rank."

        return shifts, coefs  # one row per remaining image pair

    def rcc(
        self,
        shift_max,
        t_shift,
        shifts,
        coefs,
    ):
        """
        Should probably rename function.
        Takes cross correlation results and calculates shifts.

        Solves ``coefs * drifts = shifts`` with the pseudo inverse of
        ``coefs``. For ``method == "RCC"`` the pairs with a residual larger
        than ``shift_max`` are then removed, largest residual first and only
        as long as ``coefs`` keeps full column rank, and the equation is
        solved again.

        Parameters
        ----------
        shift_max : float
            Largest accepted norm of the residual of a pair (RCC only).
        t_shift : ndarray
            Time points, returned unchanged.
        shifts : ndarray
            Measured shift of every image pair, one row per pair.
        coefs : ndarray
            Coefficient matrix matching ``shifts``. Rows of rejected pairs are
            set to zero in place.

        Returns
        -------
        t_shift : ndarray
            The time points passed in.
        drifts : ndarray
            Solved drifts with a row of zeros prepended for the first time
            point.
        """

        print("{:.2f} s. About to start solving shifts array.".format(
            time.time() - self._start_time))

        # Estimate drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() -
                                                            self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Sort and mask residual errors
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[
                residuals_dist[residuals_arg] > shift_max]

            # Remove coefs rows
            # Descending from largest residuals to small
            # Only if matrix remains full rank
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print(
                        "Could not remove all residuals over shift_max threshold."
                    )
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".
                  format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]],
                        'constant',
                        constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        """
        Calculate the drift and store it in ``namespace`` as ``output_drift``.

        Derived versions of RCC need to override this method. 'execute' of
        this RCC base class is not thoroughly tested as its use is probably
        quite limited.
        """
        # NOTE(review): this method looks inconsistent with the methods it calls:
        # calc_corr_drift_from_ft_images reads ``ft_images.shape`` but is given the
        # ``cache_fft`` file name here, and it returns two values while ``rcc``
        # expects three (t_shift, shifts, coefs) after shift_max - confirm and fix.

        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            proccess_count = np.clip(multiprocessing.cpu_count() - 1, 1, None)
            self._pool = multiprocessing.Pool(processes=proccess_count)

        finished = False
        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
            finished = True
        finally:
            # always release the worker processes, also when the calculation failed
            if self.multiprocessing:
                if finished:
                    self._pool.close()
                else:
                    self._pool.terminate()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
示例#28
0
class OctreeRenderLayer(TriangleRenderLayer):
    """
    Layer for viewing octrees. Takes in an octree, splits the faces into triangles, and then uses the rendering engines
    from PYME.LMVis.layers.triangle_mesh.
    """
    # properties to show in the GUI, in addition to the ones which are inherited
    depth = Int(
        3,
        desc=
        'Depth at which to render Octree. Set to -1 for dynamic depth rendering.'
    )
    density = Float(0.0, desc='Minimum density of octree node to display.')
    min_points = Int(10, desc='Number of points/node to truncate octree at')

    def __init__(self, pipeline, method='wireframe', dsname='', **kwargs):
        """Create the layer and redraw whenever one of the octree settings changes."""
        TriangleRenderLayer.__init__(self, pipeline, method, dsname, **kwargs)

        self.on_trait_change(self.update, 'depth')
        self.on_trait_change(self.update, 'density')
        self.on_trait_change(self.update, 'min_points')

    @property
    def _ds_class(self):
        """Octree classes this layer displays (imported lazily)."""
        from PYME.experimental import octree
        return (octree.Octree, octree.PyOctree)

    def update_from_datasource(self, ds):
        """
        Opens an octree. Subdivides the faces into triangles. Feeds the triangle points/normals to update_data.

        Only nodes whose parent holds at least ``min_points`` points are
        considered. Which of them are drawn depends on ``depth``:

        - ``depth > 0``: the nodes at that depth with a density of at least ``density``
        - ``depth == 0``: all nodes with at least one point
        - ``depth < 0``: the nodes without children

        Parameters
        ----------
        ds :
            Octree (see PYME.experimental.octree)

        Returns
        -------
        None
        """
        nodes = ds._nodes[ds._nodes[ds._nodes['parent']]['nPoints'] >= float(
            self.min_points)]

        if self.depth > 0:
            # Grab the nodes at the specified depth
            nodes = nodes[nodes['depth'] == self.depth]
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            node_density = 1. * nodes['nPoints'] / np.prod(box_sizes, axis=1)
            nodes = nodes[node_density >= self.density]
            # the number of nodes changed, so the box sizes have to be rebuilt
            box_sizes = np.ones((nodes.shape[0], 3)) * ds.box_size(self.depth)
            alpha = nodes['nPoints'] / box_sizes[:, 0]
        elif self.depth == 0:
            # plot all bins
            nodes = nodes[nodes['nPoints'] >= 1]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            # float base (as in the dynamic depth branch) so that the power cannot overflow an integer type
            alpha = nodes['nPoints'] * ((2.0**nodes['depth'])**3)

        else:
            # Dynamic depth: render the terminating nodes, i.e. the nodes which do not have any children.
            # The children have to be summed per node (axis=1); a sum over the whole array is a single
            # number and cannot be used to select nodes.
            nodes = nodes[np.sum(nodes['children'], axis=1) == 0]
            box_sizes = np.vstack(ds.box_size(nodes['depth'])).T

            alpha = nodes['nPoints'] * ((2.0**nodes['depth']))**3

        # First we need the vertices of the cube. We find them from the center c provided and the box size (lx, ly, lz)
        # provided by the octree:

        if len(nodes) > 0:
            c = nodes['centre']  # center
            v0 = c + box_sizes * -1 / 2  # v0 = c - lx/2 - ly/2 - lz/2
            v1 = c + box_sizes * [-1, -1, 1] / 2  # v1 = c - lx/2 - ly/2 + lz/2
            v2 = c + box_sizes * [-1, 1, -1] / 2  # v2 = c - lx/2 + ly/2 - lz/2
            v3 = c + box_sizes * [-1, 1, 1] / 2  # v3 = c - lx/2 + ly/2 + lz/2
            v4 = c + box_sizes * [1, -1, -1] / 2  # v4 = c + lx/2 - ly/2 - lz/2
            v5 = c + box_sizes * [1, -1, 1] / 2  # v5 = c + lx/2 - ly/2 + lz/2
            v6 = c + box_sizes * [1, 1, -1] / 2  # v6 = c + lx/2 + ly/2 - lz/2
            v7 = c + box_sizes / 2  # v7 = c + lx/2 + ly/2 + lz/2

            #
            #     z
            #     ^
            #     |
            #    v1 ----------v3
            #    /|           /|
            #   / |          / |
            #  v5----------v7  |
            #  |  |    c    |  |
            #  | v0---------|-v2
            #  | /          | /
            #  v4-----------v6---> y
            #  /
            # x
            #
            # Now note that the counterclockwise triangles (when viewed straight-on) formed along the faces of the cube are:
            #
            # v0 v2 v6
            # v0 v4 v5
            # v1 v3 v2
            # v2 v0 v1
            # v3 v1 v5
            # v3 v7 v6
            # v4 v6 v7
            # v5 v1 v0
            # v5 v7 v3
            # v6 v2 v3
            # v6 v4 v0
            # v7 v5 v4

            # Concatenate vertices, interleave, restore to 3x(3N) points (3xN triangles),
            # and assign the points to x, y, z vectors
            triangle_v0 = np.vstack(
                (v0, v0, v1, v2, v3, v3, v4, v5, v5, v6, v6, v7))
            triangle_v1 = np.vstack(
                (v2, v4, v3, v0, v1, v7, v6, v1, v7, v2, v4, v5))
            triangle_v2 = np.vstack(
                (v6, v5, v2, v1, v5, v6, v7, v0, v3, v3, v0, v4))

            x, y, z = np.hstack(
                (triangle_v0, triangle_v1, triangle_v2)).reshape(-1, 3).T

            # Now we create the normal as the cross product (triangle_v2 - triangle_v1) x (triangle_v0 - triangle_v1)
            triangle_normals = np.cross((triangle_v2 - triangle_v1),
                                        (triangle_v0 - triangle_v1))
            # We copy the normals 3 times per triangle to get 3x(3N) normals to match the vertices shape
            xn, yn, zn = np.repeat(triangle_normals.T, 3, axis=1)

            # scale so that the largest value equals the alpha of the layer, then repeat the value of
            # each node for its 12 triangles and their 3 vertices to match the order of the vertices
            alpha = self.alpha * alpha / alpha.max()
            alpha = (alpha[None, :] * np.ones(12)[:, None])
            alpha = np.repeat(alpha.ravel(), 3)
            print('Octree scaled alpha range: %g, %g' %
                  (alpha.min(), alpha.max()))

            # Pass the restructured data to update_data
            self.update_data(x,
                             y,
                             z,
                             cmap=getattr(cm, self.cmap),
                             clim=self.clim,
                             alpha=alpha,
                             xn=xn,
                             yn=yn,
                             zn=zn)
        else:
            print('No nodes for density {0}, depth {1}'.format(
                self.density, self.depth))

    @property
    def default_view(self):
        """traitsui View with the data source, method, depth, min_points, LUT and alpha settings."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View(
            [
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
                Item('method'),
                Item('depth'),
                Item('min_points'),
                Group([Item('cmap', label='LUT'),
                       Item('alpha')])
            ], )
class LayerWrapper(HasTraits):
    """Couple a named pipeline data source to a rendering engine.

    The wrapper holds the display settings that are independent of the engine
    (colour map, colour limits, alpha, visibility), instantiates the engine
    selected by ``method`` from ``ENGINES`` and pushes data into that engine
    whenever one of those settings, the data source name, or the pipeline
    changes. The ``on_update`` signal is sent after every refresh and whenever
    ``visible`` changes.
    """
    # Name of the colour map; resolved with ``getattr(cm, self.cmap)``.
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    # Colour limits handed to the engine.
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    # Key into ``ENGINES`` selecting the engine class.
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    # Data source name; '' selects the pipeline itself and 'name.channel'
    # selects one colour channel of a ColourFilter.
    dsname = CStr('output')

    def __init__(self,
                 pipeline,
                 method='points',
                 ds_name='',
                 cmap='gist_rainbow',
                 clim=None,
                 alpha=1.0,
                 visible=True,
                 method_args=None):
        """Create a wrapper around ``pipeline``.

        Parameters
        ----------
        pipeline :
            Object providing ``layer_datasources`` (mapping of name to data
            source) and an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` naming the engine to use.
        ds_name : str
            Name of the data source to display ('' for the pipeline itself).
        cmap : str
            Name of the colour map.
        clim : list of float, optional
            Colour limits; defaults to ``[0, 1]``.
        alpha : float
            Layer opacity.
        visible : bool
            Whether ``render`` draws anything.
        method_args : dict, optional
            Trait values applied to the engine once it has been created.
        """
        # Same explicit base-class call style as the other layer classes.
        HasTraits.__init__(self)

        self._pipeline = pipeline
        self.engine = None

        # These are assigned before the handlers below are registered, so
        # they do not trigger an update.
        self.cmap = cmap
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha

        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        # Copy so that the caller's dict is never shared with the wrapper.
        self._eng_params = {} if method_args is None else dict(method_args)
        # NOTE(review): the engine is only created by the 'method' change
        # handler; if ``method`` equals the trait's current value no change
        # may be reported and ``engine`` would stay None - confirm.
        self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        """Mapping of data source name to data source, taken from the pipeline."""
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        """Bounding box reported by the current engine."""
        return self.engine.bbox

    @property
    def colour_map(self):
        """Colour map reported by the current engine."""
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """List the selectable data source names.

        Every key of the namespace is included; for ``ColourFilter`` sources
        one additional ``'<name>.<channel>'`` entry is added per colour
        channel.
        """
        names = []
        for key, source in self._namespace.items():
            names.append(key)
            if isinstance(source, tabular.ColourFilter):
                names.extend('.'.join([key, chan])
                             for chan in source.getColourChans())

        return names

    @property
    def datasource(self):
        """Resolve ``dsname`` to a data source.

        Returns the pipeline itself for an empty name, the channel data
        source for a ``'<name>.<channel>'`` name, or the namespace entry
        otherwise. ``None`` is returned when the name cannot be found.
        """
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            parent = self._namespace.get(ds, None)
            if parent is None:
                # Unknown parent: report 'no data' like the plain-name branch
                # instead of failing on ``None.get_channel_ds``.
                return None
            return parent.get_channel_ds(channel)

        return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """Replace the engine with the one named by ``method`` and refresh."""
        if self.engine:
            # Carry these settings over from the engine being replaced.
            self._eng_params = self.engine.trait_get('point_size',
                                                     'vertexColour')

        # NOTE(review): ``_context`` is not assigned anywhere in this class -
        # confirm where it is expected to be provided.
        self.engine = ENGINES[self.method](self._context)
        self.engine.trait_set(**self._eng_params)
        # A new colour key resets the colour limits; any engine trait change
        # triggers a refresh.
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the currently selected colour data."""
        cdata = self._get_cdata()
        if np.size(cdata) == 0:
            # No values to take limits from; keep the current clim.
            return

        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Push the current data and display settings into the engine.

        Does nothing (beyond the trace print) while either the engine or the
        data source is unavailable; otherwise sends ``on_update`` afterwards.
        Extra arguments are accepted so the method can serve as a trait or
        signal handler, and are ignored.
        """
        print('lw update')
        if self.engine is None or self.datasource is None:
            return

        self.engine.update_from_datasource(self.datasource,
                                           getattr(cm, self.cmap), self.clim,
                                           self.alpha)
        self.on_update.send(self)

    def render(self, gl_canvas):
        """Draw the layer on ``gl_canvas`` if it is visible and has an engine."""
        if self.visible and self.engine is not None:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """Return the data column used for colouring.

        Falls back to ``np.array([0, 1])`` when there is no data source or
        engine yet, or when the engine's ``vertexColour`` key is not present
        in the data source.
        """
        datasource = self.datasource
        if datasource is None or self.engine is None:
            return np.array([0, 1])

        try:
            return datasource[self.engine.vertexColour]
        except KeyError:
            return np.array([0, 1])

    @property
    def default_view(self):
        """Build the TraitsUI view used to edit this layer.

        The engine's own editor is only included once an engine exists, and
        it is offered an empty key list while no data source is available.
        """
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        datasource = self.datasource
        ds_keys = datasource.keys() if datasource is not None else []

        contents = [
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(values=self.data_source_names)),
            ]),
            Item('method'),
        ]

        if self.engine is not None:
            contents.append(
                Group([
                    Item('engine',
                         style='custom',
                         show_label=False,
                         editor=InstanceEditor(
                             view=self.engine.view(ds_keys))),
                ]))

        contents.append(
            Group([
                Item('clim',
                     editor=HistLimitsEditor(data=self._get_cdata),
                     show_label=False),
            ]))
        contents.append(
            Group([
                Item('cmap', label='LUT'),
                Item('alpha'),
                Item('visible')
            ],
                  orientation='horizontal',
                  layout='flow'))

        return View(contents)

    def default_traits_view(self):
        """Return ``default_view`` as the view TraitsUI uses by default."""
        return self.default_view
示例#30
0
class Point3DRenderLayer(VertexRenderLayer):
    """Vertex layer drawn as ``GL_POINTS``.

    ``point_size`` is divided by the canvas pixel size at render time; the
    editor view labels it 'Size [nm]'. A ``point_size`` of 0 is drawn as if
    it were 1.
    """
    point_size = Float(30.0)

    def __init__(self,
                 x=None,
                 y=None,
                 z=None,
                 colors=None,
                 color_map=None,
                 color_limit=None,
                 alpha=1.0,
                 point_size=30):
        """Initialise the vertex layer and select the default shader program.

        All arguments except ``point_size`` are forwarded unchanged to
        ``VertexRenderLayer.__init__``.
        """
        VertexRenderLayer.__init__(self, x, y, z, colors, color_map,
                                   color_limit, alpha)
        self.point_size = point_size
        self.set_shader_program(DefaultShaderProgram)

    def render(self, gl_canvas):
        """Draw all vertices of the layer as points on ``gl_canvas``.

        The canvas bounds for 'x', 'y', 'z' and 'v' are copied onto the
        shader program before it is activated.
        """
        self.shader_program.xmin, self.shader_program.xmax = gl_canvas.bounds[
            'x']
        self.shader_program.ymin, self.shader_program.ymax = gl_canvas.bounds[
            'y']
        self.shader_program.zmin, self.shader_program.zmax = gl_canvas.bounds[
            'z']
        self.shader_program.vmin, self.shader_program.vmax = gl_canvas.bounds[
            'v']

        with self.shader_program:
            # Fetch the vertex array once and reuse it for the count and the
            # pointer upload.
            vertices = self.get_vertices()
            n_vertices = vertices.shape[0]

            glVertexPointerf(vertices)
            glNormalPointerf(self.get_normals())
            glColorPointerf(self.get_colors())

            # gl_canvas has already been dereferenced above, so the unscaled
            # fallback is only reached for a canvas object that tests falsy.
            if gl_canvas:
                # A size of 0 is drawn as 1 (before scaling by pixel size).
                size = 1 if self.point_size == 0 else self.point_size
                glPointSize(size / gl_canvas.pixelsize)
            else:
                glPointSize(self.point_size)
            glDrawArrays(GL_POINTS, 0, n_vertices)

    def get_point_size(self):
        """Deprecated accessor; read ``point_size`` directly instead."""
        # stacklevel=2 attributes the warning to the caller, which is the
        # code that needs to change.
        warnings.warn("use the point_size property instead",
                      DeprecationWarning,
                      stacklevel=2)
        return self.point_size

    def set_point_size(self, point_size):
        """Deprecated mutator; assign to ``point_size`` directly instead."""
        warnings.warn("use the point_size property instead",
                      DeprecationWarning,
                      stacklevel=2)
        self.point_size = point_size

    def view(self, ds_keys):
        """Return the TraitsUI view for this layer.

        ``ds_keys`` supplies the choices offered for the ``vertexColour``
        (colour key) selector.
        """
        from traitsui.api import View, Item, Group
        from PYME.ui.custom_traits_editors import CBEditor

        return View([
            Item('vertexColour',
                 editor=CBEditor(choices=ds_keys),
                 label='Colour'),
            Item('point_size', label='Size [nm]')
        ])