Beispiel #1
0
class RuleChain(HasTraits):
    """Traits object that stores a collection of rule factories.

    ``post_on`` and ``protocol`` are traits; ``rule_factories`` is a plain
    attribute assigned before the ``HasTraits`` initialiser runs.
    """
    post_on = Enum(POST_CHOICES)
    protocol = CStr('')

    def __init__(self, rule_factories=None, *args, **kwargs):
        """
        Parameters
        ----------
        rule_factories : list, optional
            Rule factories for this chain. When omitted a new empty list is
            created per instance, so instances never share a default list.
        *args, **kwargs
            Passed through unchanged to ``HasTraits.__init__``.
        """
        factories = rule_factories if rule_factories is not None else []
        self.rule_factories = factories
        HasTraits.__init__(self, *args, **kwargs)
Beispiel #2
0
class ClusterStats(ModuleBase):
    """Compute a per-cluster statistic of one column and map it back onto every point.

    Each point receives, in the new column ``<StatKey>_<StatMethod>``, the
    statistic ``StatMethod`` of column ``StatKey`` evaluated over all points
    that share its ``IDkey`` value.
    """

    inputName = Input('with_clumps')
    IDkey = CStr('clumpIndex')
    StatMethod = Enum(['std', 'min', 'max', 'mean', 'median', 'count', 'sum'])
    StatKey = CStr('x')
    outputName = Output('withClumpStats')

    def execute(self, namespace):
        """Add the per-cluster statistic column and store the result in ``namespace[outputName]``."""
        from scipy.stats import binned_statistic

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # The IDs are used both as bin positions and as fancy indices into the
        # per-bin result below, so they must be integers. ID columns can be
        # stored as floats, which would make ``resstat[0][ids]`` raise
        # IndexError - cast them explicitly.
        ids = np.asarray(inp[self.IDkey]).astype(int)
        prop = inp[self.StatKey]
        maxid = int(ids.max())
        # one unit-wide bin centred on each integer ID 0..maxid
        edges = -0.5 + np.arange(maxid + 2)
        resstat = binned_statistic(ids,
                                   prop,
                                   statistic=self.StatMethod,
                                   bins=edges)

        # broadcast each cluster's statistic back onto its member points
        mapped.addColumn(self.StatKey + "_" + self.StatMethod, resstat[0][ids])

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def _key_choices(self):
        """Sorted column names of the current input, or [] if they cannot be determined."""
        # try and find the available column names (best effort, used only for the GUI)
        try:
            return sorted(self._parent.namespace[self.inputName].keys())
        except Exception:
            return []

    @property
    def default_view(self):
        from traitsui.api import View, Group, Item
        from PYME.ui.custom_traits_editors import CBEditor

        return View(Item('inputName',
                         editor=CBEditor(choices=self._namespace_keys)),
                    Item('_'),
                    Item('IDkey', editor=CBEditor(choices=self._key_choices)),
                    Item('StatKey',
                         editor=CBEditor(choices=self._key_choices)),
                    Item('StatMethod'),
                    Item('_'),
                    Item('outputName'),
                    buttons=['OK'])
Beispiel #3
0
class OutputModule(ModuleBase):
    """
    Output modules are the one exception to the recipe-module functional (no side effects) programming
    paradigm and are used to perform IO and save or display designated outputs/endpoints from the recipe.

    As such, they should act solely as a sink, and should not do any processing or write anything back
    to the namespace.
    """
    filePattern = CStr('{output_dir}/{file_stub}.csv')
    scheme = Enum('File', 'pyme-cluster://', 'pyme-cluster:// - aggregate')

    def _schemafy_filename(self, out_filename):
        """Resolve ``out_filename`` according to the selected ``scheme``.

        'File' returns the name unchanged. 'pyme-cluster://' places it (with
        any leading '/' stripped) under ``clusterIO.local_dataroot``.

        Raises
        ------
        RuntimeError
            For the 'pyme-cluster:// - aggregate' scheme, which is not supported here.
        """
        if self.scheme == 'File':
            return out_filename
        elif self.scheme == 'pyme-cluster://':
            from PYME.IO import clusterIO
            import os
            return os.path.join(clusterIO.local_dataroot,
                                out_filename.lstrip('/'))
        elif self.scheme == 'pyme-cluster:// - aggregate':
            raise RuntimeError('Aggregation not suported')

    def _check_outputs(self):
        """
        This function exists to help with debugging when writing a new recipe module.

        Over-ridden here for the special case IO modules derived from OutputModule as these are permitted to
        have side-effects, but not permitted to have classical outputs to the namespace and will execute
        (when appropriate) regardless.

        Raises
        ------
        RuntimeError
            If the module declares any outputs.
        """
        if len(self.outputs) != 0:
            raise RuntimeError(
                'Output modules should not write anything to the namespace')

    def generate(self, namespace, recipe_context=None):
        """
        Function to be called from within dh5view (rather than batch processing). Some outputs are ignored, in which
        case this function returns None.

        Parameters
        ----------
        namespace : dict
            The recipe namespace.
        recipe_context : dict, optional
            Unused by this base implementation. Defaults to ``None`` rather than
            a shared mutable ``{}`` default.

        Returns
        -------
        None
            This base implementation always returns None.
        """
        return None

    def execute(self, namespace):
        """
        Output modules by definition do nothing when executed - they act as a sink and implement a save method instead.
        """
        pass
Beispiel #4
0
class ImageModuleBase(ModuleBase):
    # NOTE - if a derived class only supports, e.g. XY analysis, it should redefine this trait to only include the
    # dimensions it supports
    dimensionality = Enum(
        'XY',
        'XYZ',
        'XYZT',
        desc='Which image dimensions should the filter be applied to?')

    @property
    def processFramesIndividually(self):
        """Deprecated boolean view of ``dimensionality``.

        Emits a DeprecationWarning and a logger warning, then returns True
        exactly when ``dimensionality`` is 'XY'.
        """
        import warnings
        message = ('Use dimensionality =="XY" instead to check which dimensions a filter should be applied to, '
                   'chunking hints for computational optimisation')
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        logger.warning(message)
        return self.dimensionality == 'XY'

    @processFramesIndividually.setter
    def processFramesIndividually(self, value):
        """Deprecated setter: a truthy value selects 'XY', a falsy one 'XYZ'."""
        import warnings
        message = ('Use dimensionality ="XY" instead to check which dimensions a filter should be applied to, '
                   'chunking hints for computational optimisation')
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        logger.warning(message)

        self.dimensionality = 'XY' if value else 'XYZ'
Beispiel #5
0
class LabelByRegionProperty(Filter):
    """Assigns a region property to each contiguous region in the input mask.

    Optionally throws away all regions for which the property is outside
    ``[propertyMin, propertyMax]``.

    Regions for which the property has a zero denominator (``aspectratio``
    with zero minor axis length, ``circularity`` with zero perimeter - e.g.
    single-pixel regions) are assigned ``inf`` instead of raising.
    """
    regionProperty = Enum(['area', 'circularity', 'aspectratio'])
    filterByProperty = Bool(False)
    propertyMin = Float(0)
    propertyMax = Float(1e6)

    def _region_property(self, region):
        """Return the selected ``regionProperty`` for one skimage RegionProperties object."""
        if self.regionProperty == 'area':
            return region.area
        if self.regionProperty == 'aspectratio':
            minor = region.minor_axis_length
            if minor == 0:
                # degenerate (line-like or single-pixel) region
                return np.inf
            return region.major_axis_length / minor
        if self.regionProperty == 'circularity':
            perimeter = region.perimeter
            if perimeter == 0:
                # skimage reports a zero perimeter for single-pixel regions
                return np.inf
            return 4 * math.pi * region.area / (perimeter * perimeter)
        raise ValueError('Unknown region property: %s' % self.regionProperty)

    def applyFilter(self, data, chanNum, frNum, im):
        """Return a float image in which each region's pixels hold its property value."""
        mask = data > 0.5
        labs, _ = ndimage.label(mask)
        rp = skimage.measure.regionprops(labs, None, cache=True)

        m2 = np.zeros_like(mask, dtype='float')
        objs = ndimage.find_objects(labs)
        for region in rp:
            oslices = objs[region.label - 1]
            r = labs[oslices] == region.label
            propValue = self._region_property(region)

            if self.filterByProperty and not (
                    self.propertyMin <= propValue <= self.propertyMax):
                continue

            # Write only into this region's own pixels. The former
            # ``m2[oslices] += r * propValue`` produced NaN (0 * inf) across
            # the rest of the bounding box whenever propValue was infinite.
            region_view = m2[oslices]
            region_view[r] += propValue

        return m2

    def completeMetadata(self, im):
        im.mdh['Labelling.Property'] = self.regionProperty
Beispiel #6
0
class FlexiThreshold(Filter):
    """Choose a threshold using one of a range of available thresholding methods.

    Supported ``method`` values: simple, fractional, otsu, isodata, li, yen.
    'simple' uses ``parameter`` directly as the threshold, 'fractional' uses
    :meth:`fractionalThreshold`, and the rest call the matching
    ``skf.threshold_<method>`` function. For the automatic methods the data
    are clipped at ``clipAt`` before estimating the threshold.
    """
    method = Enum(
        'simple', 'fractional', 'otsu', 'isodata', 'li',
        'yen')  # newer skimage has minimum, mean and triangle as well
    parameter = Float(0.5)
    # used to be 10 - increase to large value for newer PYME renderings
    clipAt = Float(2e6)

    def fractionalThreshold(self, data):
        """Threshold at which the cumulative intensity-weighted histogram reaches
        ``(1 - parameter)`` of its total, i.e. roughly a fraction ``parameter``
        of the summed intensity lies above it.

        Returns the left edge of the selected histogram bin (5000 bins).
        """
        counts, edges = np.histogram(data, bins=5000)
        # NB: each bin is weighted by its *left* edge, not its centre
        left_edges = edges[:-1]
        cumulative = np.cumsum(counts * left_edges)
        target = cumulative[-1] * (1 - self.parameter)
        best_bin = np.argmin(abs(cumulative - target))
        return edges[best_bin]

    def applyFilter(self, data, chanNum, frNum, im):
        """Return the boolean mask ``data > threshold``."""
        if self.method == 'simple':
            threshold = self.parameter
        else:
            clipped = np.clip(data, None, self.clipAt)
            if self.method == 'fractional':
                threshold = self.fractionalThreshold(clipped)
            else:
                estimator = getattr(skf, 'threshold_%s' % self.method)
                threshold = estimator(clipped)

        return data > threshold

    def completeMetadata(self, im):
        im.mdh['Processing.ThresholdParameter'] = self.parameter
        im.mdh['Processing.ThresholdMethod'] = self.method
class PointCloudRenderLayer(EngineLayer):
    """
    A layer for viewing point-cloud data, drawn by one of the rendering engines
    in ``ENGINES`` (chosen through the ``method`` trait).

    Coordinates, optional normals and the colouring variable are read from the
    pipeline datasource named by ``dsname``.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    point_size = Float(30.0, desc='Rendered size of the points in nm')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Point tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display points')
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='points',
                 dsname='',
                 context=None,
                 **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, point_size')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        self.method = method

        self._set_method()

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if not self._pipeline is None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Instantiate the engine selected by ``method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """Return the colouring column, or ``np.array([0, 1])`` if the datasource lacks it."""
        try:
            cdata = self.datasource[self.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the min/max of the colouring variable (the clim change triggers ``update``)."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Refresh GUI choice lists and, when an engine and datasource exist, redraw and notify listeners."""
        self._datasource_choices = self._pipeline.layer_data_source_names
        if not self.datasource is None:
            self._datasource_keys = sorted(self.datasource.keys())

        if not (self.engine is None or self.datasource is None):
            self.update_from_datasource(self.datasource)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """Pull coordinates, colour values and (if present) normals from ``ds`` and push them to ``update_data``.

        Missing z (or ``z_key is None``) and an empty ``vertexColour`` are
        replaced by zeros shaped like x.
        """
        x, y = ds[self.x_key], ds[self.y_key]

        if not self.z_key is None:
            try:
                z = ds[self.z_key]
            except KeyError:
                z = 0 * x
        else:
            z = 0 * x

        if not self.vertexColour == '':
            c = ds[self.vertexColour]
        else:
            c = 0 * x

        if self.xn_key in ds.keys():
            xn, yn, zn = ds[self.xn_key], ds[self.yn_key], ds[self.zn_key]
            self.update_data(x,
                             y,
                             z,
                             c,
                             cmap=getattr(cm, self.cmap),
                             clim=self.clim,
                             alpha=self.alpha,
                             xn=xn,
                             yn=yn,
                             zn=zn)
        else:
            self.update_data(x,
                             y,
                             z,
                             c,
                             cmap=getattr(cm, self.cmap),
                             clim=self.clim,
                             alpha=self.alpha)

    def update_data(self,
                    x=None,
                    y=None,
                    z=None,
                    colors=None,
                    cmap=None,
                    clim=None,
                    alpha=1.0,
                    xn=None,
                    yn=None,
                    zn=None):
        """Build vertex, normal and RGBA arrays from per-point data and store them.

        Vertices are stacked into an (N, 3) array. Without normals, a constant
        -0.69 normal is used. Colours are mapped through ``cmap`` after
        scaling by ``clim`` (alpha column set to ``alpha``). If any of
        ``cmap``, ``colors`` or ``clim`` is missing, every vertex is white
        (all-ones RGBA).
        """
        self._vertices = None
        self._normals = None
        self._colors = None
        self._color_map = None
        self._color_limit = 0
        self._alpha = 0
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if not xn is None:
                normals = np.vstack(
                    (xn.ravel(), yn.ravel(),
                     zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                normals = -0.69 * np.ones(vertices.shape)

            self._bbox = np.array(
                [x.min(), y.min(),
                 z.min(), x.max(),
                 y.max(), z.max()])
        else:
            vertices = None
            normals = None
            self._bbox = None

        # cmap must be checked too - it is called below (the original tested clim twice)
        if cmap is not None and colors is not None and clim is not None:
            cs_ = ((colors - clim[0]) / (clim[1] - clim[0]))
            cs = cmap(cs_)
            cs[:, 3] = alpha

            cs = cs.ravel().reshape(len(colors), 4)
        else:
            if not vertices is None:
                cs = np.ones((vertices.shape[0], 4), 'f')
            else:
                cs = None

        self.set_values(vertices, normals, cs, cmap, clim, alpha)

    def set_values(self,
                   vertices=None,
                   normals=None,
                   colors=None,
                   color_map=None,
                   color_limit=None,
                   alpha=None):
        """Store each supplied (non-None) value; None arguments leave the current value untouched."""
        if vertices is not None:
            self._vertices = vertices
        if normals is not None:
            self._normals = normals
        if color_map is not None:
            self._color_map = color_map
        if colors is not None:
            self._colors = colors
        if color_limit is not None:
            self._color_limit = color_limit
        if alpha is not None:
            self._alpha = alpha

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     visible_when=
                     "method in ['pointsprites', 'transparent_points']",
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)),
                Item('point_size',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)))
        ])

    def default_traits_view(self):
        return self.default_view
class RGBImageUpload(ImageUpload):
    """
    Render an image to an RGB (png) file and upload it to an OMERO server,
    optionally attaching localization files.

    Parameters
    ----------
    input_image : str
        name of the image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular names (keys) to attachment filenames (values). Tabulars
        are saved as '.hdf' files and attached to the image
    filePattern : str
        pattern determining the name of the image on the OMERO server
    omero_dataset : str
        name of the OMERO dataset to add the image to; created if it does not
        already exist
    omero_project : str
        name of the OMERO project to link the dataset to; created if it does
        not already exist
    zoom : float
        how much to zoom the image
    scaling : str
        how to scale the intensity - either 'min-max' or 'percentile'
    scaling_factor : float
        `percentile` scaling only - which percentile to use
    colorblind_friendly : bool
        Use cyan, magenta and yellow rather than RGB for multi-channel images.
        True by default.

    Notes
    -----
    OMERO server address and user login information must be stored in the user
    PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """
    filePattern = '{file_stub}.png'
    scaling = Enum(['min-max', 'percentile'])
    scaling_factor = Float(0.95)
    zoom = Int(1)
    colorblind_friendly = Bool(True)

    def _save(self, image, path):
        """Render ``image`` to an 8-bit colour array and write it to ``path`` with PIL.

        CMY rendering is used only when ``colorblind_friendly`` is set and the
        image has more than one channel (``image.data.shape[3] != 1``).
        """
        from PIL import Image
        from PYME.IO.rgb_image import image_to_rgb, image_to_cmy

        # short-circuit keeps image.data untouched when colorblind mode is off
        use_cmy = self.colorblind_friendly and image.data.shape[3] != 1
        render = image_to_cmy if use_cmy else image_to_rgb
        pixels = render(image,
                        zoom=self.zoom,
                        scaling=self.scaling,
                        scaling_factor=self.scaling_factor)

        Image.fromarray(pixels, mode='RGB').save(path)
class ImageUpload(OutputModule):
    """
    Upload a PYME ImageStack to an OMERO server, optionally
    attaching localization files.

    Parameters
    ----------
    input_image : str
        name of image in the recipe namespace to upload
    input_localization_attachments : dict
        maps tabular names (keys) to attachment filenames (values). Tabulars
        will be saved as '.hdf' files and attached to the image
    filePattern : str
        pattern to determine name of image on OMERO server. 'file_stub' will be
        set automatically.
    omero_dataset : str
        name of OMERO dataset to add the image to. If the dataset does not
        already exist it will be created. Can use sample metadata entries
        using {format} syntax
    omero_project : str
        name of OMERO project to link the dataset to. If the project does not
        already exist it will be created. Can use sample metadata entries
        using {format} syntax

    Notes
    -----
    OMERO server address and user login information must be stored in the user
    PYME config directory under plugins/config/pyme-omero, e.g.
    /Users/Andrew/.PYME/plugins/config/pyme-omero. The file should be yaml
    formatted with the following keys:
        user
        password
        address
        port [optional]

    The project/image/dataset will be owned by the user set in the yaml file.
    """

    input_image = Input('')
    input_localization_attachments = DictStrStr()

    filePattern = '{file_stub}.tif'

    scheme = Enum(['OMERO'])

    omero_project = CStr('')
    omero_dataset = CStr('{Sample.SlideRef}')

    def _save(self, image, path):
        """Save ``image`` as a TIFF and return the path actually written.

        The extension of ``path`` is replaced by '.tif'; the returned path
        reflects that change so the caller uploads the file that exists.
        """
        # force tif extension
        path = os.path.splitext(path)[0] + '.tif'
        image.Save(path)
        return path

    def save(self, namespace, context=None):
        """
        Parameters
        ----------
        namespace : dict
            The recipe namespace
        context : dict, optional
            Information about the source file to allow pattern substitution to
            generate the output name. At least 'file_stub' (which is the
            filename without any extension) should be resolved. If it also
            holds 'input_dir', an '.h5r' file named after 'file_stub' in that
            directory is additionally uploaded as an annotation (if present).
        """
        from pyme_omero import core
        from tempfile import TemporaryDirectory

        if context is None:
            context = {}

        out_filename = self.filePattern.format(**context)

        im = namespace[self.input_image]

        if hasattr(im, 'mdh'):
            sample = Sample()  # hack around our md keys having periods in them
            for k in [k for k in im.mdh.keys() if k.startswith('Sample.')]:
                setattr(sample, k.split('Sample.')[-1], im.mdh[k])
            sample_md = dict(Sample=sample)
        else:
            sample_md = {}

        dataset = self.omero_dataset.format(**sample_md)
        project = self.omero_project.format(**sample_md)

        with TemporaryDirectory() as temp_dir:
            out_filename = os.path.join(temp_dir, out_filename)
            # _save may alter the filename (e.g. forcing an extension); upload
            # the file that was actually written. Overrides returning None keep
            # the requested name.
            saved_filename = self._save(im, out_filename)
            if saved_filename is not None:
                out_filename = saved_filename

            loc_filenames = []
            for loc_key, loc_stub in self.input_localization_attachments.items():
                loc_filename = os.path.join(temp_dir, loc_stub)
                if os.path.splitext(loc_filename)[-1] == '':
                    # default to hdf unless h5r is manually specified
                    loc_filename = loc_filename + '.hdf'
                loc_filenames.append(loc_filename)
                try:
                    mdh = namespace[loc_key].mdh
                except AttributeError:
                    mdh = None
                namespace[loc_key].to_hdf(loc_filename, loc_key, metadata=mdh)

            image_id = core.upload_image_from_file(out_filename, dataset,
                                                   project, loc_filenames)

        # if an h5r file is the principal input, upload it
        try:
            principal = os.path.join(context['input_dir'],
                                     context['file_stub']) + '.h5r'
            with core.local_or_named_temp_filename(principal) as f:
                core.connect_and_upload_file_annotation(
                    image_id, f, namespace='pyme.localizations')
        except (KeyError, IOError):
            pass

    @property
    def inputs(self):
        """Names of all namespace entries this module reads: the image plus every attachment key."""
        return set(self.input_localization_attachments.keys()).union(
            set([self.input_image]))

    @property
    def default_view(self):
        import wx
        if wx.GetApp() is None:
            return None

        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import DictChoiceStrEditor, CBEditor

        inputs, outputs, params = self.get_params()

        return View([
            Item(name='input_image',
                 editor=CBEditor(choices=self._namespace_keys)),
        ] + [
            Item(name='input_localization_attachments',
                 editor=DictChoiceStrEditor(choices=self._namespace_keys)),
        ] + [
            Item('_'),
        ] + self._view_items(params),
                    buttons=['OK', 'Cancel'])

    @property
    def pipeline_view(self):
        return self.default_view

    @property
    def no_localization_view(self):
        import wx
        if wx.GetApp() is None:
            return None

        from traitsui.api import View, Item
        from PYME.ui.custom_traits_editors import CBEditor

        inputs, outputs, params = self.get_params()

        return View([
            Item(name='input_image',
                 editor=CBEditor(choices=self._namespace_keys)),
        ] + [
            Item('_'),
        ] + self._view_items(params),
                    buttons=['OK', 'Cancel'])
Beispiel #10
0
class CalculateFRCFromLocs(CalculateFRCBase):
    """
    Generates a pair of images from localization data and calculates the Fourier shell/ring correlation (FSC / FRC).
    
    Inputs
    ------
    inputName : TabularBase
        Localization data.
        
    Outputs
    -------
    outputName : TabularBase
        Localization data labeled with how it was divided (FRC_group).
    output_images : ImageStack
        Pair of 2D or 3D histogram rendered images.
    output_fft_images_cc : ImageStack
        Fast Fourier transform original and cross-correlation images.
    output_frc_dict : dict
        FSC/FRC results.
    output_frc_plot : Plot
        Output plot of the FSC / FRC curve.
    output_frc_raw : dict
        Complete FSC/FRC results.
    
    Parameters
    ----------
    split_method : string
        Different methods of dividing data into halves.
    pixel_size_in_nm : float
        Pixel size used for rendering the images.
    flatten_z : Bool
        If enabled ignores z information and only performs a FRC.
    pre_filter : string
        Methods to filter the images prior to Fourier transform.
    frc_smoothing_func : string
        Methods to smooth the FSC / FRC curve.
    cubic_smoothing  : float
        Smoothing factor for cubic spline.
    multiprocessing : Bool
        Enables multiprocessing.
    save_path : File
        (Optional) File path to save output
    
    """
    inputName = Input('Localizations')
    split_method = Enum(['halves_random', 'halves_time', 'halves_100_time_chunk', 'halves_10_time_chunk', 'fixed_time', 'fixed_10_time_chunk'])
    pixel_size_in_nm = Float(5)
    # NOTE(review): ``multiprocessing``, ``output_frc_dict`` and ``output_frc_raw``
    # are used below but not declared here - presumably inherited from
    # CalculateFRCBase; confirm there.
    flatten_z = Bool(True)

    outputName = Output('FRC_ds')
    output_images = Output('FRC_images')

    def execute(self, namespace):
        """Split the localizations into two groups, render both and compute their FRC/FSC.

        Writes ``outputName`` (via ``generate_image_pair``), ``output_images``,
        ``output_frc_dict`` and ``output_frc_raw`` into ``namespace``, then
        calls ``save_to_file``.
        """
        import multiprocessing

        self._namespace = namespace

        if self.multiprocessing:
            # Up to 2 workers, leaving one core free, but never fewer than one:
            # np.clip(2, 1, cpu_count() - 1) gave 0 on single-core machines,
            # which multiprocessing.Pool rejects.
            process_count = max(1, min(2, multiprocessing.cpu_count() - 1))
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            pipeline = namespace[self.inputName]
            mapped_pipeline = tabular.mappingFilter(pipeline)
            self._pixel_size_in_nm = self.pixel_size_in_nm * np.ones(3, dtype=float)

            image_pair = self.generate_image_pair(mapped_pipeline)

            image_pair = self.preprocess_images(image_pair)

            # Should use DensityMapping recipe eventually when it is ready.
            mdh = MetaDataHandler.NestedClassMDHandler()
            mdh['Rendering.Method'] = "np.histogramdd"
            if 'imageID' in pipeline.mdh.getEntryNames():
                mdh['Rendering.SourceImageID'] = pipeline.mdh['imageID']
            try:
                mdh['Rendering.SourceFilename'] = pipeline.resultsSource.h5f.filename
            except Exception:
                # best effort - not every pipeline is backed by an HDF results file
                pass
            mdh.Source = MetaDataHandler.NestedClassMDHandler(pipeline.mdh)
            mdh['Rendering.NEventsRendered'] = [image_pair[0].sum(), image_pair[1].sum()]
            mdh['voxelsize.units'] = 'um'
            mdh['voxelsize.x'] = self.pixel_size_in_nm * 1E-3
            mdh['voxelsize.y'] = self.pixel_size_in_nm * 1E-3

            ims = ImageStack(data=np.stack(image_pair, axis=-1), mdh=mdh)
            namespace[self.output_images] = ims

            frc_res, rawdata = self.calculate_FRC_from_images(image_pair, pipeline.mdh)

            namespace[self.output_frc_dict] = frc_res
            namespace[self.output_frc_raw] = rawdata
        finally:
            # always release the worker processes, even if the calculation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        self.save_to_file(namespace)

    def generate_image_pair(self, mapped_pipeline):
        """Split localizations into two groups and histogram each into an image.

        The group assignment is stored as column 'FRC_group' and the labelled
        pipeline is written to ``namespace[outputName]``. Returns the two
        histogram images (group 0, group 1) with bins of ``pixel_size_in_nm``
        over x, y (and z unless ``flatten_z``).
        """
        # Split localizations into 2 sets
        mask = np.zeros(mapped_pipeline['t'].shape, dtype=bool)
        if self.split_method == 'halves_time':
            sort_arg = np.argsort(mapped_pipeline['t'])
            mask[sort_arg[:len(sort_arg)//2]] = 1
        elif self.split_method == 'halves_random':
            mask[:len(mask)//2] = 1
            np.random.shuffle(mask)
        elif self.split_method == 'halves_100_time_chunk':
            sort_arg = np.argsort(mapped_pipeline['t'])
            chunksize = mask.shape[0] / 100.0
            for i in range(50):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'halves_10_time_chunk':
            sort_arg = np.argsort(mapped_pipeline['t'])
            chunksize = mask.shape[0] * 0.1
            for i in range(5):
                mask[sort_arg[int(np.round(i*2*chunksize)):int(np.round((i*2+1)*chunksize))]] = 1
        elif self.split_method == 'fixed_time':
            # np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed
            time_cutoff = (np.ptp(mapped_pipeline['t']) + 1) // 2 + mapped_pipeline['t'].min()
            mask[mapped_pipeline['t'] < time_cutoff] = 1
        elif self.split_method == 'fixed_10_time_chunk':
            time_cutoffs = np.linspace(mapped_pipeline['t'].min(), mapped_pipeline['t'].max(), 11, dtype=float)
            for i in range((time_cutoffs.shape[0]-1)//2):
                mask[(mapped_pipeline['t'] > time_cutoffs[i*2]) & (mapped_pipeline['t'] < time_cutoffs[i*2+1])] = 1

        if self.split_method.startswith('halves_'):
            assert np.abs(mask.sum() - (~mask).sum()) <= 1, "datasets uneven, {} vs {}".format(mask.sum(), (~mask).sum())
        else:
            print("variable counts between images, {} vs {}".format(mask.sum(), (~mask).sum()))

        mapped_pipeline.addColumn("FRC_group", mask)
        self._namespace[self.outputName] = mapped_pipeline

        dims = ['x', 'y', 'z']
        if self.flatten_z:
            dims.remove('z')

        # Simple hist2d binning
        bins = list()
        data = list()
        for dim in dims:
            bins.append(np.arange(np.floor(mapped_pipeline[dim].min()).astype(int), np.ceil(mapped_pipeline[dim].max()).astype(int) + self.pixel_size_in_nm, self.pixel_size_in_nm))

            data.append(mapped_pipeline[dim])

        data = np.stack(data, axis=1)

        if self.multiprocessing:
            results = list()
            results.append(self._pool.apply_async(np.histogramdd, (data[mask==0],), kwds= {"bins":bins}))
            results.append(self._pool.apply_async(np.histogramdd, (data[mask==1],), kwds= {"bins":bins}))

            image_a = results[0].get()[0]
            image_b = results[1].get()[0]
        else:
            image_a = np.histogramdd(data[mask==0], bins=bins)[0]
            image_b = np.histogramdd(data[mask==1], bins=bins)[0]

        return image_a, image_b
Beispiel #11
0
class TriangleRenderLayer(EngineLayer):
    """
    Layer for viewing triangle meshes.

    The mesh named by ``dsname`` is fetched from the pipeline, every face is expanded into three independent
    vertices (one per corner), and per-vertex normals and RGBA colours are computed for the rendering engine
    selected by ``method``.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('constant', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames, default='gist_rainbow', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Face tranparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display faces')
    normal_mode = Enum(['Per vertex', 'Per face'])
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as a source of triangles (should be a TriangularMesh object)')
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='wireframe', dsname='', context=None, **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  # TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self.xn_key = 'xn'
        self.yn_key = 'yn'
        self.zn_key = 'zn'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname, normal_mode')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            # assigning a different method fires the 'method' listener, which calls _set_method
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to, as returned by the pipeline's ``get_layer_data`` for ``dsname``.
        """
        return self._pipeline.get_layer_data(self.dsname)

    @property
    def _ds_class(self):
        """Class that pipeline datasources must be instances of to be offered as mesh sources."""
        from PYME.experimental import _triangle_mesh as triangle_mesh
        return triangle_mesh.TrianglesBase

    def _set_method(self):
        """Create the rendering engine for the current ``method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the current colour variable, or ``np.array([0, 1])`` if they cannot be looked up
        (unknown key, or a datasource that does not support item lookup).
        """
        try:
            cdata = self.datasource[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Rescale ``clim`` to the range of the newly selected colour variable and refresh the layer."""
        cdata = self._get_cdata()
        new_clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]
        if list(self.clim) == new_clim:
            # clim is unchanged, so the 'clim' listener may not fire; refresh explicitly so the new colour
            # variable is still picked up
            self.update(*args, **kwargs)
        else:
            # the 'clim' listener calls self.update(); calling it here as well would rebuild the mesh twice
            self.clim = new_clim

    def update(self, *args, **kwargs):
        """
        Refresh the lists of selectable datasources / colour keys and, if both an engine and a datasource are
        available, rebuild the geometry and notify listeners via ``on_update``.
        """
        self._datasource_choices = [k for k, v in self._pipeline.dataSources.items() if isinstance(v, self._ds_class)]

        # fetch once - each access goes through the pipeline
        ds = self.datasource

        if ds is not None:
            dks = ['constant', ]
            if hasattr(ds, 'keys'):
                dks = dks + sorted(ds.keys())
            self._datasource_keys = dks

        if not (self.engine is None or ds is None):
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Rebuild the vertex, normal and colour arrays used by the engine from a triangle mesh.

        Each face is expanded into three independent vertices (one per corner), so every output array has one row
        per face corner.

        Parameters
        ----------
        ds :
            Triangle mesh exposing ``vertices``, ``faces``, ``vertex_normals`` and ``face_normals``, and supporting
            ``ds[key]`` lookup of per-vertex values when ``vertexColour`` is not '' / 'constant'.

        Returns
        -------
        None
        """
        # gather the 3 corners of every face, flattened to one (x, y, z) row per corner
        x, y, z = ds.vertices[ds.faces].reshape(-1, 3).T

        if self.normal_mode == 'Per vertex':
            xn, yn, zn = ds.vertex_normals[ds.faces].reshape(-1, 3).T
        else:
            # repeat each face normal for the 3 corners of that face
            # (assumes face_normals is (n_faces, 3) - TODO confirm)
            xn, yn, zn = np.repeat(ds.face_normals.T, 3, axis=1)

        if self.vertexColour in ['', 'constant']:
            c = np.ones(len(x))
            clim = [0, 1]
        else:
            c = ds[self.vertexColour][ds.faces].ravel()
            clim = self.clim

        cmap = getattr(cm, self.cmap)
        alpha = float(self.alpha)

        # Do we have coordinates? Concatenate into vertices.
        if x is not None and y is not None and z is not None:
            vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
            self._vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

            if xn is not None:
                self._normals = np.vstack((xn.ravel(), yn.ravel(), zn.ravel())).T.ravel().reshape(len(x.ravel()), 3)
            else:
                self._normals = -0.69 * np.ones(self._vertices.shape)

            self._bbox = np.array([x.min(), y.min(), z.min(), x.max(), y.max(), z.max()])
        else:
            self._bbox = None

        if clim is not None and c is not None and cmap is not None:
            span = clim[1] - clim[0]
            if span == 0:
                # a constant colour variable gives clim[0] == clim[1]; avoid 0/0 -> NaN colours
                span = 1.0
            cs_ = (c - clim[0]) / span
            cs = cmap(cs_)

            if self.method in ['flat', 'tessel']:
                # these methods modulate transparency by the normalised colour value
                alpha = cs_ * alpha

            cs[:, 3] = alpha

            if self.method == 'tessel':
                cs = np.power(cs, 0.333)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if self._vertices is not None:
                self._colors = np.ones((self._vertices.shape[0], 4), 'f')

        self._alpha = alpha
        self._color_map = cmap
        self._color_limit = clim

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Item('method'),
                     Item('normal_mode', visible_when='method=="shaded"'),
                     Item('vertexColour', editor=EnumEditor(name='_datasource_keys'), label='Colour'),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ], visible_when='vertexColour != "constant"'),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        return self.default_view
class AlignPSF(ModuleBase):
    """
        Align PSF stacks by redundant cross correlation.

        The input image is treated as a set of PSF stacks indexed by its 4th axis. Every pair of stacks is
        cross-correlated and the correlation peak is fitted to obtain a relative 3D shift. The resulting
        over-determined system is solved by least squares, discarding badly fitting pairs, to obtain one shift per
        stack. The stacks are then shifted in Fourier space so they are mutually aligned and centred.
    """

    inputName = Input('psf_cropped')
    normalize_z = Bool(True)  # normalise intensity per z plane (True) or per PSF stack (False)
    tukey = Float(0.50)  # Tukey window shape parameter applied before correlating; <= 0 disables windowing
    rcc_tolerance = Float(5.0)  # residual threshold for rejecting pairs (scaled by 1E3 / voxelsize.x in execute)
    z_crop_half_roi = Int(15)  # half-width (in planes) of the central z range used to estimate the shifts
    peak_detect = Enum(['Gaussian', 'RBF'])  # method used to locate the cross-correlation peak
    debug = Bool(False)  # if True, output the shifted cleaned/windowed stacks instead of the raw data
    output_cross_corr_images = Output('cross_cor_img')
    output_cross_corr_images_fitted = Output('cross_cor_img_fitted')
    output_images = Output('psf_aligned')

    def execute(self, namespace):
        self._namespace = namespace
        ims = namespace[self.inputName]

        # X, Y, Z, 'C'
        psf_stack = ims.data[:, :, :, :]

        # only the central 2 * z_crop_half_roi + 1 planes are used to estimate the shifts; clamp the start so a
        # short stack does not produce a negative (wrapping) slice start
        z_centre = psf_stack.shape[2] // 2
        z_slice = slice(max(0, z_centre - self.z_crop_half_roi), z_centre + self.z_crop_half_roi + 1)
        cleaned_psf_stack = self.normalize_images(psf_stack[:, :, z_slice, :])

        if self.tukey > 0:
            # separable Tukey apodisation of x, y, z to suppress edge artefacts in the cross-correlation
            masks = [signal.tukey(dim_len, self.tukey) for dim_len in cleaned_psf_stack.shape[:3]]
            masks = np.prod(np.meshgrid(*masks, indexing='ij'), axis=0)
            cleaned_psf_stack *= masks[:, :, :, None]

        # NOTE(review): the units of rcc_tolerance relative to mdh['voxelsize.x'] are not visible here - confirm
        drifts = self.calculate_shifts(cleaned_psf_stack, self.rcc_tolerance * 1E3 / ims.mdh['voxelsize.x'])

        namespace[self.output_images] = ImageStack(self.shift_images(cleaned_psf_stack if self.debug else psf_stack, drifts), mdh=ims.mdh)

    def normalize_images(self, psf_stack):
        """
        Return a background-subtracted, intensity-normalised float copy of ``psf_stack``.

        Negative values are clipped, the per-stack minimum is subtracted, and intensities are scaled so the
        maximum (per plane if ``normalize_z``, else per stack) becomes 1.05. 0.05 is then subtracted and the result
        clipped at 0, which suppresses the lowest ~5 % of the range.
        """
        # in case it is already bg subtracted; work on a float copy so the in-place division below also works for
        # integer input data
        cleaned_psf_stack = np.clip(psf_stack, 0, None).astype(float)

        # subtract bg per stack
        cleaned_psf_stack -= cleaned_psf_stack.min(axis=(0, 1, 2), keepdims=True)

        if self.normalize_z:
            # normalize intensity per plane
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1), keepdims=True) / 1.05
        else:
            # normalize intensity per psf stack
            cleaned_psf_stack /= cleaned_psf_stack.max(axis=(0, 1, 2), keepdims=True) / 1.05

        cleaned_psf_stack -= 0.05
        np.clip(cleaned_psf_stack, 0, None, cleaned_psf_stack)

        return cleaned_psf_stack

    def calculate_shifts(self, psf_stack, drift_tolerance):
        """
        Estimate one 3D shift per PSF stack (4th axis of ``psf_stack``) by redundant cross-correlation.

        Parameters
        ----------
        psf_stack : ndarray
            Cleaned PSF stacks indexed [x, y, z, stack].
        drift_tolerance : float
            Pairs whose least-squares residual exceeds this are dropped, as long as the system stays full rank.

        Returns
        -------
        ndarray
            (n_stacks, 3) shifts, offset so that the bright part of the mean PSF is centred.

        Also stores the (thresholded) cross-correlation images and their fitted models in the namespace.
        """
        n_steps = psf_stack.shape[3]
        # one row per unordered pair (i, j) with i < j
        coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        corr_shape = psf_stack.shape[:3] + (coefs_size,)
        output_cross_corr_images = np.zeros(corr_shape, dtype=float)
        output_cross_corr_images_fitted = np.zeros(corr_shape, dtype=float)

        counter = 0
        for i in np.arange(0, n_steps - 1):
            for j in np.arange(i + 1, n_steps):
                # the shift between stacks i and j is the sum of the consecutive steps i .. j-1
                coefs[counter, i:j] = 1

                print("compare {} to {}".format(i, j))
                correlate_result = signal.correlate(psf_stack[:, :, :, i], psf_stack[:, :, :, j], mode="same")
                correlate_result -= correlate_result.min()
                correlate_result /= correlate_result.max()

                # keep only the upper half of the normalised correlation for peak fitting
                threshold = 0.50
                correlate_result[correlate_result < threshold] = np.nan

                labeled_image, labeled_counts = ndimage.label(~np.isnan(correlate_result))
                # protects against > 1 peak in the cross correlation results
                # shouldn't happen anyway, but at least avoid fitting a single peak to multi-modal data:
                # keep only the connected component containing the highest correlation value
                if labeled_counts > 1:
                    component_maxima = ndimage.maximum(correlate_result, labeled_image, np.arange(labeled_counts) + 1)
                    strongest_label = np.argmax(component_maxima) + 1
                    correlate_result[labeled_image != strongest_label] = np.nan

                output_cross_corr_images[:, :, :, counter] = np.nan_to_num(correlate_result)

                # pixel coordinates along each axis, centred on zero
                dims = [np.arange(dim) - np.arange(dim).mean() for dim in correlate_result.shape]

                if self.peak_detect == "Gaussian":
                    res = optimize.least_squares(guassian_nd_error,
                                                 [1, 0, 0, 5., 0, 5., 0, 30.],
                                                 args=(dims, correlate_result))
                    output_cross_corr_images_fitted[:, :, :, counter] = gaussian_nd(res.x, dims)

                    # parameters 2, 4 and 6 are used as the x, y, z peak positions
                    shifts[counter, 0] = res.x[2]
                    shifts[counter, 1] = res.x[4]
                    shifts[counter, 2] = res.x[6]
                elif self.peak_detect == "RBF":
                    rbf_interpolator = build_rbf(dims, correlate_result)
                    # NOTE(review): dims are centred on zero but the start point is half the image size - confirm
                    res = optimize.minimize(rbf_nd_error, [correlate_result.shape[0]*0.5, correlate_result.shape[1]*0.5, correlate_result.shape[2]*0.5], args=rbf_interpolator)
                    output_cross_corr_images_fitted[:, :, :, counter] = rbf_nd(rbf_interpolator, dims)
                    shifts[counter, :] = res.x
                else:
                    raise ValueError("peak founding method not recognised")

                counter += 1

        self._namespace[self.output_cross_corr_images] = ImageStack(output_cross_corr_images)
        self._namespace[self.output_cross_corr_images_fitted] = ImageStack(output_cross_corr_images_fitted)

        # least-squares solution for the consecutive steps, and how well each pair agrees with it
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)
        residuals = np.matmul(coefs, drifts) - shifts
        residuals_dist = np.linalg.norm(residuals, axis=1)

        shift_max = drift_tolerance
        # Sort residual errors (largest first) and keep only those above shift_max
        residuals_arg = np.argsort(-residuals_dist)
        residuals_arg = residuals_arg[residuals_dist[residuals_arg] > shift_max]

        # Remove coefs rows
        # Descending from largest residuals to small
        # Only if matrix remains full rank
        coefs_temp = np.empty_like(coefs)
        n_removed = 0
        for index in residuals_arg:
            coefs_temp[:] = coefs
            coefs_temp[index, :] = 0
            if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                coefs[:] = coefs_temp
                n_removed += 1
            else:
                print("Could not remove all residuals over shift_max threshold.")
                break
        print("removed {} in total".format(n_removed))
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        # prepend a zero shift for the first stack and accumulate the steps into absolute shifts
        drifts = np.pad(drifts, [[1, 0], [0, 0]], 'constant', constant_values=0)
        np.cumsum(drifts, axis=0, out=drifts)

        # centre the result: centre of mass of the bright (> 50 % of max) part of the mean normalised PSF,
        # relative to the image centre
        psf_stack_mean = psf_stack / psf_stack.mean(axis=(0, 1, 2), keepdims=True)
        psf_stack_mean = psf_stack_mean.mean(axis=3)
        psf_stack_mean *= psf_stack_mean > psf_stack_mean.max() * 0.5
        center_offset = ndimage.center_of_mass(psf_stack_mean) - np.asarray(psf_stack_mean.shape)*0.5

        drifts = drifts - center_offset

        # Diagnostic scatter plots of the final shifts. Any failure (e.g. `pyplot` not being available in this
        # module's scope) is printed and otherwise ignored.
        try:
            fig, axes = pyplot.subplots(1, 2, figsize=(6, 3))

            axes[0].scatter(drifts[:, 0], drifts[:, 1], s=50)
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            axes[1].scatter(drifts[:, 0], drifts[:, 2], s=50)
            axes[1].set_xlabel('x')
            axes[1].set_ylabel('z')

            for ax in axes:
                ax.axvline(0, color='red', ls='--')
                ax.axhline(0, color='red', ls='--')

            fig.tight_layout()
        except Exception as e:
            print(e)

        return drifts

    def shift_images(self, psf_stack, shifts):
        """
        Shift each stack ``psf_stack[:, :, :, i]`` by ``shifts[i]`` (in pixels) using the Fourier shift theorem.

        The magnitude of the inverse FFT is returned, in an array with the same shape and dtype as ``psf_stack``.
        """
        kx = (np.fft.fftfreq(psf_stack.shape[0]))
        ky = (np.fft.fftfreq(psf_stack.shape[1]))
        kz = (np.fft.fftfreq(psf_stack.shape[2]))
        kx, ky, kz = np.meshgrid(kx, ky, kz, indexing='ij')

        shifted_images = np.zeros_like(psf_stack)
        for i in np.arange(psf_stack.shape[3]):
            psf = psf_stack[:, :, :, i]
            ft_image = np.fft.fftn(psf)
            shift = shifts[i]
            shifted_images[:, :, :, i] = np.abs(np.fft.ifftn(ft_image*np.exp(-2j*np.pi*(kx*shift[0] + ky*shift[1] + kz*shift[2]))))

        return shifted_images
# Beispiel #13
# 0
class FiducialTrack(ModuleBase):
    """
    Extract average fiducial track from input pipeline

    Parameters
    ----------

        radiusMultiplier: this number is multiplied with error_x to obtain search radius for clustering
        timeWindow: the window along the time dimension used for clustering
        filterScale: the size of the filter kernel used to smooth the resulting average fiducial track
        filterMethod: enumerated choice of filter methods for the smoothing operation
            (the keys of PYMEcs.Analysis.trackFiducials.FILTER_FUNCS)
        clumpMinSize: minimum clump size passed to the fiducial clustering step
        singleFiducial: if True, all fiducial data is assumed to come from a single fiducial, so track fragments
            are not aligned against each other

    Notes
    -----

    Output is a new pipeline (a mapping filter over the input) with a ``fiducial_<dim>`` column for every
    dimension returned by the track averaging step (e.g. fiducial_x, fiducial_y), plus an ``isFiducial`` column.
    Metadata is propagated from the input when present.

    """
    import PYMEcs.Analysis.trackFiducials as tfs
    inputName = Input('filtered')

    radiusMultiplier = Float(5.0)
    timeWindow = Int(25)
    filterScale = Float(11)
    # materialise the dict keys as a list so Enum receives a plain sequence of choices rather than a dict view
    filterMethod = Enum(list(tfs.FILTER_FUNCS.keys()))
    clumpMinSize = Int(50)
    singleFiducial = Bool(True)

    outputName = Output('fiducialAdded')

    def execute(self, namespace):
        import PYMEcs.Analysis.trackFiducials as tfs

        inp = namespace[self.inputName]
        mapped = tabular.mappingFilter(inp)

        # if all data is from a single fiducial we do not need to align;
        # we then avoid problems with incomplete tracks giving rise to offsets between
        # fiducial track fragments
        align = not self.singleFiducial

        t, x, y, z, isFiducial = tfs.extractTrajectoriesClump(
            inp,
            clumpRadiusVar='error_x',
            clumpRadiusMultiplier=self.radiusMultiplier,
            timeWindow=self.timeWindow,
            clumpMinSize=self.clumpMinSize,
            align=align)
        rawtracks = (t, x, y, z)
        tracks = tfs.AverageTrack(inp,
                                  rawtracks,
                                  filter=self.filterMethod,
                                  filterScale=self.filterScale,
                                  align=align)

        # add tracks for all calculated dims to output
        for dim in tracks.keys():
            mapped.addColumn('fiducial_%s' % dim, tracks[dim])
        mapped.addColumn('isFiducial', isFiducial)

        # propagate metadata, if present
        try:
            mapped.mdh = inp.mdh
        except AttributeError:
            pass

        namespace[self.outputName] = mapped

    @property
    def hide_in_overview(self):
        # NOTE(review): this class defines no 'columns' trait - confirm this entry is still needed
        return ['columns']
# Beispiel #14
# 0
class TrackRenderLayer(EngineLayer):
    """
    A layer for viewing tracking data

    Points are grouped into tracks (clumps), either from a ClumpManager or from a tabular datasource using the
    ``clump_key`` column, and re-ordered so that each track's points are contiguous. ``clumpStarts`` and
    ``clumpSizes`` give the start offset and point count of every track in the vertex arrays.
    """

    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    vertexColour = CStr('', desc='Name of variable used to colour our points')
    cmap = Enum(*cm.cmapnames,
                default='gist_rainbow',
                desc='Name of colourmap used to colour points')
    clim = ListFloat(
        [0, 1],
        desc='How our variable should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Tranparency')
    line_width = Float(1.0, desc='Track line width')
    method = Enum(*ENGINES.keys(), desc='Method used to display tracks')
    clump_key = CStr('clumpIndex',
                     desc="Name of column containing the track identifier")
    dsname = CStr(
        'output',
        desc=
        'Name of the datasource within the pipeline to use as a source of points'
    )
    _datasource_keys = List()
    _datasource_choices = List()

    def __init__(self,
                 pipeline,
                 method='tracks',
                 dsname='',
                 context=None,
                 **kwargs):
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gist_rainbow'

        self.x_key = 'x'  #TODO - make these traits?
        self.y_key = 'y'
        self.z_key = 'z'

        self._bbox = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(self._update, 'vertexColour')
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update,
                             'cmap, clim, alpha, dsname, clump_key')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.set(**kwargs)

        # update datasource name and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            # assigning a different method fires the 'method' listener, which calls _set_method;
            # calling it again here would create the engine twice
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if self._pipeline is not None:
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource we are connected to (through our dsname property).
        """
        return self._pipeline.get_layer_data(self.dsname)

    def _set_method(self):
        """Create the rendering engine for the current ``method`` and refresh the layer."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def _get_cdata(self):
        """
        Return the values of the current colour variable, or ``np.array([0, 1])`` when they are unavailable
        (unknown key, or a missing datasource).
        """
        ds = self.datasource
        try:
            if isinstance(ds, ClumpManager):
                cdata = []
                for track in ds.all:
                    cdata.extend(track[self.vertexColour])
                cdata = np.array(cdata)
            else:
                # Assume tabular dataset
                cdata = ds[self.vertexColour]
        except (KeyError, TypeError):
            cdata = np.array([0, 1])

        return cdata

    def _update(self, *args, **kwargs):
        """Rescale ``clim`` to the range of the newly selected colour variable and refresh the layer."""
        cdata = self._get_cdata()
        new_clim = [float(np.nanmin(cdata)), float(np.nanmax(cdata))]
        if list(self.clim) == new_clim:
            # clim is unchanged, so the 'clim' listener may not fire; refresh explicitly so the new colour
            # variable is still picked up
            self.update()
        else:
            # the 'clim' listener calls self.update()
            self.clim = new_clim

    def update(self, *args, **kwargs):
        """
        Refresh the lists of selectable datasources / colour keys and, if both an engine and a datasource are
        available, rebuild the geometry and notify listeners via ``on_update``.
        """
        self._datasource_choices = self._pipeline.layer_data_source_names

        # fetch once - each access goes through the pipeline
        ds = self.datasource

        if ds is not None:
            if isinstance(ds, ClumpManager):
                # Grab the keys from the first Track in the ClumpManager
                self._datasource_keys = sorted(ds[0].keys())
            else:
                # Assume we have a tabular data source
                self._datasource_keys = sorted(ds.keys())

        if not (self.engine is None or ds is None):
            self.update_from_datasource(ds)
            self.on_update.send(self)

    @property
    def bbox(self):
        return self._bbox

    def update_from_datasource(self, ds):
        """
        Rebuild vertices, normals and colours from ``ds``, with each track's points stored contiguously.

        ``ds`` is either a ClumpManager (tracks are read in ``ds.all`` order) or a tabular datasource, whose
        rows are grouped by the ``clump_key`` column.
        """
        if isinstance(ds, ClumpManager):
            x = []
            y = []
            z = []
            c = []
            self.clumpSizes = []

            # Copy data from tracks. This is already in clump order
            # thanks to ClumpManager
            for track in ds.all:
                x.extend(track['x'])
                y.extend(track['y'])
                z.extend(track['z'])
                self.clumpSizes.append(track.nEvents)

                if not self.vertexColour == '':
                    c.extend(track[self.vertexColour])
                else:
                    c.extend([0 for i in track['x']])

            x = np.array(x)
            y = np.array(y)
            z = np.array(z)
            c = np.array(c)

        else:
            # Assume tabular data source
            x, y = ds[self.x_key], ds[self.y_key]

            if self.z_key is not None:
                try:
                    z = ds[self.z_key]
                except KeyError:
                    z = 0 * x
            else:
                z = 0 * x

            if not self.vertexColour == '':
                c = ds[self.vertexColour]
            else:
                c = 0 * x

            # Work out clump start and finish indices
            # TODO - optimize / precompute????
            # NOTE(review): assumes clump indices run from 1 to max(ci); an index of 0 would wrap to the last clump
            ci = ds[self.clump_key]

            NClumps = int(ci.max())

            clist = [[] for i in range(NClumps)]
            for i, cl_i in enumerate(ci):
                clist[int(cl_i - 1)].append(i)

            # This and self.clumpStarts are instance attributes for
            # compatibility with the old Track rendering layer,
            # PYME.LMVis.gl_render3D.TrackLayer
            self.clumpSizes = [len(cl_i) for cl_i in clist]

            #reorder x, y, z, c in clump order
            I = np.hstack([np.array(cl) for cl in clist]).astype(int)

            x = x[I]
            y = y[I]
            z = z[I]
            c = c[I]

        self.clumpStarts = np.cumsum([
            0,
        ] + self.clumpSizes)

        #do normal vertex stuff
        vertices = np.vstack((x.ravel(), y.ravel(), z.ravel()))
        vertices = vertices.T.ravel().reshape(len(x.ravel()), 3)

        self._vertices = vertices
        self._normals = -0.69 * np.ones(vertices.shape)
        self._bbox = np.array(
            [x.min(), y.min(),
             z.min(), x.max(),
             y.max(), z.max()])

        clim = self.clim
        cmap = getattr(cm, self.cmap)

        if clim is not None:
            span = clim[1] - clim[0]
            if span == 0:
                # a constant colour variable gives clim[0] == clim[1]; avoid 0/0 -> NaN colours
                span = 1.0
            cs_ = ((c - clim[0]) / span)
            cs = cmap(cs_)
            cs[:, 3] = float(self.alpha)

            self._colors = cs.ravel().reshape(len(c), 4)
        else:
            if vertices is not None:
                self._colors = np.ones((vertices.shape[0], 4), 'f')

        self._color_map = cmap
        self._color_limit = clim
        self._alpha = float(self.alpha)

    def get_vertices(self):
        return self._vertices

    def get_normals(self):
        return self._normals

    def get_colors(self):
        return self._colors

    def get_color_map(self):
        return self._color_map

    @property
    def colour_map(self):
        return self._color_map

    def get_color_limit(self):
        return self._color_limit

    @property
    def default_view(self):
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor, TextEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor, CBEditor

        return View([
            Group([
                Item('dsname',
                     label='Data',
                     editor=EnumEditor(name='_datasource_choices')),
            ]),
            Item('method'),
            Item(
                'vertexColour',
                editor=EnumEditor(name='_datasource_keys'),
                label='Colour',
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                [
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata,
                                                 update_signal=self.on_update),
                         show_label=False),
                ],
                visible_when='cmap not in ["R", "G", "B", "C", "M","Y", "K"]'),
            Group(
                Item('cmap', label='LUT'),
                Item('alpha',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float)), Item('line_width'))
        ])

    def default_traits_view(self):
        return self.default_view
class RCCDriftCorrectionBase(CacheCleanupModule):
    """
    Performs drift correction using redundant cross-correlation from
    Wang et al. Optics Express 2014 22:13 (Bo Huang's RCC algorithm).
    Base class for other RCC recipes.
    Can take cached fft input (as filename, not an 'input').
    Only output drift as tuple of time points, and drift amount.
    Currently not registered by itself since not very useful.

    ``method`` selects which image pairs are correlated and how drift is solved:

    * ``'RCC'`` (redundant cross-correlation): every pair whose frame
      separation is within ``corr_window`` is correlated. After a first
      least-squares solve, pairs whose residual exceeds ``shift_max`` are
      dropped (only while the system stays full rank), and the drift is
      solved again.
    * ``'MCC'`` (mean cross-correlation): the same pairs as RCC, solved once
      with no outlier rejection.
    * ``'DCC'`` (direct cross-correlation): only the first image is
      correlated against every later image.
    """

    cache_fft = File("rcc_cache.bin")
    # redundant cross-correlation, mean cross-correlation, direct cross-correlation
    method = Enum(['RCC', 'MCC', 'DCC'])
    # residual threshold above which RCC discards a correlated pair
    shift_max = Float(5)  # nm
    # maximum frame separation of correlated pairs; <= 0 correlates every pair
    corr_window = Int(5)
    multiprocessing = Bool()
    # if not blank, cross-correlation images are written to this memmap file
    debug_cor_file = File()

    output_drift = Output('drift')
    output_drift_plot = Output('drift_plot')

    # if debug_cor_file not blank, filled with imagestack of cross correlation
    output_cross_cor = Output('cross_cor')

    def calc_corr_drift_from_ft_images(self, ft_images):
        """
        Cross-correlate pairs of Fourier-transformed images and assemble the
        linear system relating pairwise shifts to frame-to-frame drifts.

        Parameters
        ----------
        ft_images : ndarray
            Fourier-transformed images, one per time step along axis 0. When
            ``debug_cor_file`` is set, the array is indexed as 4D and its last
            axis is treated as a half spectrum (presumably ``rfftn`` output;
            TODO confirm against derived classes).

        Returns
        -------
        shifts : ndarray, shape (n_pairs, 3)
            Measured shift for each correlated pair. Pairs that produced NaNs
            are removed.
        coefs : ndarray, shape (n_pairs, n_steps - 1)
            Row ``k`` is 1 over the frame-to-frame steps spanned by pair ``k``.

        Raises
        ------
        AssertionError
            If a coefficient row is left empty, or if the remaining system is
            not of full column rank.
        """
        n_steps = ft_images.shape[0]

        # The number of correlated pairs is fixed by the method, so the
        # coefficient and shift arrays can be preallocated.
        if self.method == "DCC":
            coefs_size = n_steps - 1
        elif self.corr_window > 0:
            # Clamp to the largest possible separation. For
            # corr_window >= n_steps the closed form would undercount the pairs
            # actually visited below (or even go negative).
            window = min(self.corr_window, n_steps - 1)
            coefs_size = n_steps * window - window * (window + 1) // 2
        else:
            coefs_size = n_steps * (n_steps - 1) // 2
        coefs = np.zeros((coefs_size, n_steps - 1))
        shifts = np.zeros((coefs_size, 3))

        # Per-pair work queued for the worker pool when multiprocessing.
        ft_1_cache = list()
        ft_2_cache = list()
        autocor_shift_cache = list()

        write_debug_cc = self.debug_cor_file != ""
        if write_debug_cc:
            cc_file_shape = [
                shifts.shape[0], ft_images.shape[1], ft_images.shape[2],
                (ft_images.shape[3] - 1) * 2
            ]

            # flatten shortest dimension to reduce cross correlation to 2d images for easier debugging
            min_arg = min(enumerate(cc_file_shape[1:]),
                          key=lambda x: x[1])[0] + 1
            cc_file_shape.pop(min_arg)
            # np.float (an alias of the builtin float, i.e. float64) was
            # removed in NumPy 1.24.
            cc_file_args = (self.debug_cor_file, np.float64,
                            tuple(cc_file_shape))
            cc_file = np.memmap(cc_file_args[0],
                                dtype=cc_file_args[1],
                                mode="w+",
                                shape=cc_file_args[2])
            # Must be a list: cc_args is indexed per pair below, and zip()
            # returns a one-shot iterator in Python 3.
            cc_args = list(zip(range(shifts.shape[0]),
                               (cc_file_args, ) * shifts.shape[0]))
        else:
            cc_args = (None, ) * shifts.shape[0]

        # When workers read ft_images from the cache file, hand them indices
        # instead of arrays.
        pass_indices = self.multiprocessing and self.cache_fft != ""

        counter = 0
        # For each ft image, correlate against later images
        for i in range(n_steps - 1):
            if self.method == "DCC" and i > 0:
                break

            ft_1 = ft_images[i, :, :]

            autocor_shift = calc_shift(ft_1, ft_1)

            for j in range(i + 1, n_steps):
                if (self.method != "DCC") and (self.corr_window > 0) and (
                        j - i > self.corr_window):
                    break

                ft_2 = ft_images[j, :, :]

                # pair (i, j) measures the summed drift of steps i .. j-1
                coefs[counter, i:j] = 1

                if self.multiprocessing:
                    ft_1_cache.append(i if pass_indices else ft_1)
                    ft_2_cache.append(j if pass_indices else ft_2)
                    autocor_shift_cache.append(autocor_shift)
                else:
                    shifts[counter, :] = calc_shift(ft_1, ft_2, autocor_shift,
                                                    None, cc_args[counter])

                    if (counter + 1) % max(coefs_size // 5, 1) == 0:
                        print(
                            "{:.2f} s. Completed calculating {} of {} total shifts."
                            .format(time.time() - self._start_time,
                                    counter + 1, coefs_size))

                counter += 1

        if self.multiprocessing:
            args = zip(
                range(len(autocor_shift_cache)), ft_1_cache, ft_2_cache,
                autocor_shift_cache,
                len(ft_1_cache) *
                ((self.cache_fft, ft_images.dtype, ft_images.shape), ),
                cc_args)
            # results arrive out of order; each carries its pair index j
            for i, (j, res) in enumerate(
                    self._pool.imap_unordered(calc_shift_helper, args)):
                shifts[j, ] = res

                if (i + 1) % max(coefs_size // 5, 1) == 0:
                    print(
                        "{:.2f} s. Completed calculating {} of {} total shifts."
                        .format(time.time() - self._start_time, i + 1,
                                coefs_size))

        print("{:.2f} s. Finished calculating all shifts.".format(
            time.time() - self._start_time))
        print("{:,} bytes".format(coefs.nbytes))
        print("{:,} bytes".format(shifts.nbytes))

        if write_debug_cc:
            # move time axis for ImageStack
            cc_file = np.moveaxis(cc_file, 0, 2)
            self.trait_setq(**{"_cc_image": ImageStack(data=cc_file.copy())})
            del cc_file
        else:
            self.trait_setq(**{"_cc_image": None})

        assert (np.all(np.any(
            coefs, axis=1))), "Coefficient matrix filled less than expected."

        # drop pairs whose correlation produced NaNs
        mask = np.where(~np.isnan(shifts).any(axis=1))[0]
        if len(mask) < shifts.shape[0]:
            print("Removed {} cross correlations due to bad/missing data?".
                  format(shifts.shape[0] - len(mask)))
            coefs = coefs[mask, :]
            shifts = shifts[mask, :]

        assert (coefs.shape[0] > 0) and (
            np.linalg.matrix_rank(coefs) == n_steps -
            1), "Something went wrong with coefficient matrix. Not full rank."

        # shifts.shape[0] == coefs.shape[0] == number of retained pairs
        return shifts, coefs

    def rcc(
        self,
        shift_max,
        t_shift,
        shifts,
        coefs,
    ):
        """
        Solve pairwise shifts for frame-to-frame drift (should probably be renamed).

        Parameters
        ----------
        shift_max : float
            Residual threshold; only used when ``method == 'RCC'``.
        t_shift :
            Time points; returned unchanged.
        shifts : ndarray, shape (n_pairs, 3)
            Measured pairwise shifts.
        coefs : ndarray, shape (n_pairs, n_steps - 1)
            Coefficient matrix. NOTE: modified in place when
            ``method == 'RCC'`` (rows of rejected pairs are zeroed).

        Returns
        -------
        t_shift
            The input, unchanged.
        drifts : ndarray, shape (n_steps, 3)
            Frame-to-frame drift, with a zero row prepended for the first
            time point.
        """
        print("{:.2f} s. About to start solving shifts array.".format(
            time.time() - self._start_time))

        # Least-squares estimate of the drift
        drifts = np.matmul(np.linalg.pinv(coefs), shifts)

        print("{:.2f} s. Done solving shifts array.".format(time.time() -
                                                            self._start_time))

        if self.method == "RCC":

            # Calculate residual errors
            residuals = np.matmul(coefs, drifts) - shifts
            residuals_dist = np.linalg.norm(residuals, axis=1)

            # Rows with residual above threshold, largest first
            residuals_arg = np.argsort(-residuals_dist)
            residuals_arg = residuals_arg[
                residuals_dist[residuals_arg] > shift_max]

            # Zero out coefs rows, from largest residual down, but only while
            # the matrix stays full column rank.
            coefs_temp = np.empty_like(coefs)
            counter = 0
            for index in residuals_arg:
                coefs_temp[:] = coefs
                coefs_temp[index, :] = 0
                if np.linalg.matrix_rank(coefs_temp) == coefs.shape[1]:
                    coefs[:] = coefs_temp
                    counter += 1
                else:
                    print(
                        "Could not remove all residuals over shift_max threshold."
                    )
                    break
            print("removed {} in total".format(counter))

            # Estimate drift again
            drifts = np.matmul(np.linalg.pinv(coefs), shifts)

            print("{:.2f} s. RCC completed. Repeated solving shifts array.".
                  format(time.time() - self._start_time))

        # pad with 0 drift for first time point
        drifts = np.pad(drifts, [[1, 0], [0, 0]],
                        'constant',
                        constant_values=0)

        return t_shift, drifts

    def _execute(self, namespace):
        # derived versions of RCC need to override this method
        # 'execute' of this RCC base class is not thoroughly tested as its use is probably quite limited.
        # NOTE(review): with this class's own calc_corr_drift_from_ft_images()
        # and rcc(), this call chain cannot succeed. The former is handed the
        # cache *filename* rather than an FFT array. The latter expects
        # (t_shift, shifts, coefs) but receives only (shifts, coefs).
        self._start_time = time.time()
        print("Starting drift correction module.")

        if self.multiprocessing:
            process_count = max(multiprocessing.cpu_count() - 1, 1)
            self._pool = multiprocessing.Pool(processes=process_count)

        try:
            drift_res = self.calc_corr_drift_from_ft_images(self.cache_fft)
            t_shift, shifts = self.rcc(self.shift_max, *drift_res)
        finally:
            # always release the worker processes, even if correlation fails
            if self.multiprocessing:
                self._pool.close()
                self._pool.join()

        # convert frame-to-frame drift to drift from origin
        shifts = np.cumsum(shifts, 0)

        namespace[self.output_drift] = t_shift, shifts
Beispiel #16
0
class ImageRenderLayer(EngineLayer):
    """
    Layer for viewing images.

    Shows a single (slice, channel) plane of a ``PYME.IO.image.ImageStack``
    datasource. The colour-mapping settings (``cmap``, ``clim``, ``alpha``)
    are stored for the rendering engine selected by ``method``.
    """
    # properties to show in the GUI. Note that we also inherit 'visible' from BaseLayer
    cmap = Enum(*cm.cmapnames, default='gray', desc='Name of colourmap used to colour faces')
    clim = ListFloat([0, 1], desc='How our data should be scaled prior to colour mapping')
    alpha = Float(1.0, desc='Transparency')
    method = Enum(*ENGINES.keys(), desc='Method used to display image')
    dsname = CStr('output', desc='Name of the datasource within the pipeline to use as an image')
    channel = Int(0)
    slice = Int(0)
    z_pos = Float(0)
    _datasource_choices = List()
    _datasource_keys = List()

    def __init__(self, pipeline, method='image', dsname='', display_opts=None, context=None, **kwargs):
        """
        Parameters
        ----------
        pipeline :
            Object providing ``get_layer_data`` / ``dataSources`` (and
            optionally an ``onRebuild`` signal), or a plain dict of datasources.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        dsname : str
            Name of the image datasource to display.
        display_opts :
            Optional dh5view display-options instance used by
            ``sync_to_display_opts``.
        context :
            Forwarded to ``EngineLayer.__init__``.
        **kwargs :
            Initial trait values.
        """
        EngineLayer.__init__(self, context=context, **kwargs)
        self._pipeline = pipeline
        self.engine = None
        self.cmap = 'gray'

        self._bbox = None
        # a dh5view display_options instance; if provided, sync_to_display_opts()
        # copies its clim, cmap and visibility for our channel
        self._do = display_opts

        # selection key and source object of the currently cached image plane
        self._im_key = None
        self._im_ds = None

        # define a signal so that people can be notified when we are updated (currently used to force a redraw when
        # parameters change)
        self.on_update = dispatch.Signal()

        # define responses to changes in various traits
        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        # update any of our traits which were passed as command line arguments
        self.trait_set(**kwargs)

        # update datasource and method
        self.dsname = dsname
        if self.method == method:
            # make sure we still call _set_method even if we start with the default method
            self._set_method()
        else:
            self.method = method

        # if we were given a pipeline, connect ourselves to the onRebuild signal so that we can automatically update
        # ourselves
        if (self._pipeline is not None) and hasattr(pipeline, 'onRebuild'):
            self._pipeline.onRebuild.connect(self.update)

    @property
    def datasource(self):
        """
        Return the datasource named ``dsname``.

        Uses ``pipeline.get_layer_data`` when available. Otherwise the pipeline
        is treated as a dict, and ``None`` is returned if the name is missing
        (``update`` skips rendering in that case).
        """
        try:
            return self._pipeline.get_layer_data(self.dsname)
        except AttributeError:
            # fallback if pipeline is a dictionary
            return self._pipeline.get(self.dsname, None)

    @property
    def _ds_class(self):
        """Datasource type this layer can display."""
        from PYME.IO import image
        return image.ImageStack

    def _set_method(self):
        """(Re)create the rendering engine for the current ``method`` and refresh."""
        self.engine = ENGINES[self.method](self._context)
        self.update()

    def update(self, *args, **kwargs):
        """
        Refresh the list of selectable image datasources. If an engine and a
        datasource are available, also re-read the image and fire ``on_update``.

        Extra positional and keyword arguments are accepted and ignored, so
        this can be used directly as a trait / signal handler.
        """
        try:
            sources = self._pipeline.dataSources.items()
        except AttributeError:
            sources = self._pipeline.items()
        self._datasource_choices = [k for k, v in sources if isinstance(v, self._ds_class)]

        if self.engine is None:
            return

        ds = self.datasource
        if ds is None:
            return

        self.update_from_datasource(ds)
        self.on_update.send(self)

    @property
    def bbox(self):
        """Bounding box ``[x0, y0, 0, x1, y1, 0]`` of the current image, or None."""
        return self._bbox

    def sync_to_display_opts(self, do=None):
        """
        Copy clim, cmap and visibility for ``self.channel`` from a dh5view
        display-options object.

        Parameters
        ----------
        do :
            Display options to use. Defaults to the ``display_opts`` given at
            construction. If neither is available, this does nothing.
        """
        if do is None:
            if self._do is None:
                return
            do = self._do

        o = do.Offs[self.channel]
        g = do.Gains[self.channel]
        # display options store offset/gain; convert to [min, max] limits
        clim = [o, o + 1.0 / g]

        cmap = do.cmaps[self.channel].name
        visible = do.show[self.channel]

        self.trait_set(clim=clim, cmap=cmap, visible=visible)

    def update_from_datasource(self, ds):
        """
        Cache the selected image plane and the current colour settings.

        Parameters
        ----------
        ds :
            PYME.IO.image.ImageStack object

        Returns
        -------
        None
        """
        cmap = getattr(cm, self.cmap)

        im_key = (self.dsname, self.slice, self.channel)

        # Re-slice when the selection changes, or when we are handed a different
        # datasource object (e.g. after the pipeline rebuilds). Otherwise only
        # colour settings changed and the cached plane is reused.
        if self._im_key != im_key or self._im_ds is not ds:
            self._im_key = im_key
            self._im_ds = ds
            self._im = ds.data[:, :, self.slice, self.channel].astype('f4')

            x0, y0, x1, y1, _, _ = ds.imgBounds.bounds

            self._bbox = np.array([x0, y0, 0, x1, y1, 0])

            self._bounds = [x0, y0, x1, y1]

        self._alpha = float(self.alpha)
        self._color_map = cmap
        self._color_limit = self.clim

    def get_color_map(self):
        """Return the colour map object last applied."""
        return self._color_map

    @property
    def colour_map(self):
        """Property alias for :meth:`get_color_map`."""
        return self._color_map

    def get_color_limit(self):
        """Return the colour limits last applied."""
        return self._color_limit

    def _get_cdata(self):
        """Return every 20th pixel of the cached plane (used for the clim histogram)."""
        return self._im.ravel()[::20]

    @property
    def default_view(self):
        """traitsui View: datasource selector, histogram clim editor, LUT and alpha."""
        from traitsui.api import View, Item, Group, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View([Group([Item('dsname', label='Data', editor=EnumEditor(name='_datasource_choices')), ]),
                     Group([Item('clim', editor=HistLimitsEditor(data=self._get_cdata), show_label=False), ]),
                     Group([Item('cmap', label='LUT'),
                            Item('alpha', visible_when='method in ["flat", "tessel"]')
                            ])
                     ], )

    def default_traits_view(self):
        """traits hook: use :attr:`default_view`."""
        return self.default_view
class LayerWrapper(HasTraits):
    """
    Binds a rendering engine (chosen by ``method`` from ``ENGINES``) to a named
    datasource in a pipeline. Exposes colour-mapping settings (cmap, clim,
    alpha, visibility) for editing in a traitsui view.
    """
    cmap = Enum(*cm.cmapnames, default='gist_rainbow')
    clim = ListFloat([0, 1])
    alpha = Float(1.0)
    visible = Bool(True)
    method = Enum(*ENGINES.keys())
    engine = Instance(layers.BaseLayer)
    dsname = CStr('output')

    def __init__(self,
                 pipeline,
                 method='points',
                 ds_name='',
                 cmap='gist_rainbow',
                 clim=None,
                 alpha=1.0,
                 visible=True,
                 method_args=None,
                 context=None):
        """
        Parameters
        ----------
        pipeline :
            Pipeline providing ``layer_datasources`` and an ``onRebuild`` signal.
        method : str
            Key into ``ENGINES`` selecting the rendering engine.
        ds_name : str
            Datasource name. ``''`` uses the pipeline itself; ``'name.chan'``
            selects a channel of a ColourFilter datasource.
        cmap, clim, alpha, visible :
            Initial colour-mapping settings. ``clim`` defaults to ``[0, 1]``.
        method_args : dict, optional
            Initial trait values for the engine.
        context : optional
            Rendering context passed to the engine constructor.
        """
        self._pipeline = pipeline
        # _set_method() constructs engines with self._context; it must exist
        # before the first engine is created.
        self._context = context
        self.engine = None

        self.cmap = cmap
        self.clim = [0, 1] if clim is None else clim
        self.alpha = alpha

        self.visible = visible

        self.on_update = dispatch.Signal()

        self.on_trait_change(lambda: self.on_update.send(self), 'visible')
        self.on_trait_change(self.update, 'cmap, clim, alpha, dsname')
        self.on_trait_change(self._set_method, 'method')

        self.dsname = ds_name

        self._eng_params = dict(method_args or {})
        if self.method == method:
            # assigning the current value does not fire the change handler,
            # so create the engine explicitly
            self._set_method()
        else:
            self.method = method

        self._pipeline.onRebuild.connect(self.update)

    @property
    def _namespace(self):
        """Mapping of datasource names to datasources from the pipeline."""
        return self._pipeline.layer_datasources

    @property
    def bbox(self):
        """Bounding box reported by the current engine."""
        return self.engine.bbox

    @property
    def colour_map(self):
        """Colour map reported by the current engine."""
        return self.engine.colour_map

    @property
    def data_source_names(self):
        """Selectable datasource names, including ``'name.channel'`` entries for ColourFilters."""
        names = []
        for k, v in self._namespace.items():
            names.append(k)
            if isinstance(v, tabular.ColourFilter):
                for c in v.getColourChans():
                    names.append('.'.join([k, c]))

        return names

    @property
    def datasource(self):
        """
        Resolve ``dsname`` to a datasource.

        ``''`` returns the pipeline itself. ``'name.channel'`` returns that
        channel of a ColourFilter datasource. Returns ``None`` if the named
        datasource does not exist.
        """
        if self.dsname == '':
            return self._pipeline

        parts = self.dsname.split('.')
        if len(parts) == 2:
            # special case - permit access to channels using dot notation
            # NB: only works if our underlying datasource is a ColourFilter
            ds, channel = parts
            base = self._namespace.get(ds, None)
            if base is None:
                return None
            return base.get_channel_ds(channel)
        else:
            return self._namespace.get(self.dsname, None)

    def _set_method(self):
        """Replace the engine for the new ``method``, carrying over shared engine settings."""
        if self.engine:
            self._eng_params = self.engine.trait_get('point_size', 'vertexColour')

        self.engine = ENGINES[self.method](self._context)
        self.engine.trait_set(**self._eng_params)
        self.engine.on_trait_change(self._update, 'vertexColour')
        self.engine.on_trait_change(self.update)

        self.update()

    def _update(self, *args, **kwargs):
        """Reset ``clim`` to the range of the newly selected colour data."""
        cdata = self._get_cdata()
        self.clim = [float(cdata.min()), float(cdata.max())]

    def update(self, *args, **kwargs):
        """Push the datasource and colour settings to the engine, then fire ``on_update``."""
        if not (self.engine is None or self.datasource is None):
            self.engine.update_from_datasource(self.datasource,
                                               getattr(cm, self.cmap),
                                               self.clim, self.alpha)
            self.on_update.send(self)

    def render(self, gl_canvas):
        """Render the engine on ``gl_canvas`` if this layer is visible."""
        if self.visible:
            self.engine.render(gl_canvas)

    def _get_cdata(self):
        """Return the column used for colouring, or ``[0, 1]`` if it is missing."""
        try:
            cdata = self.datasource[self.engine.vertexColour]
        except KeyError:
            cdata = np.array([0, 1])

        return cdata

    @property
    def default_view(self):
        """traitsui View: datasource, method, engine settings, clim histogram and display options."""
        from traitsui.api import View, Item, Group, InstanceEditor, EnumEditor
        from PYME.ui.custom_traits_editors import HistLimitsEditor

        return View(
            [
                Group([
                    Item('dsname',
                         label='Data',
                         editor=EnumEditor(values=self.data_source_names)),
                ]),
                Item('method'),
                Group([
                    Item('engine',
                         style='custom',
                         show_label=False,
                         editor=InstanceEditor(
                             view=self.engine.view(self.datasource.keys()))),
                ]),
                Group([
                    Item('clim',
                         editor=HistLimitsEditor(data=self._get_cdata),
                         show_label=False),
                ]),
                Group([
                    Item('cmap', label='LUT'),
                    Item('alpha'),
                    Item('visible')
                ],
                      orientation='horizontal',
                      layout='flow')
            ], )

    def default_traits_view(self):
        """traits hook: use :attr:`default_view`."""
        return self.default_view
Beispiel #18
0
class CalculateFRCBase(ModuleBase):
    """
    Base class. Refer to derived classes for docstrings.

    Shared Fourier ring/shell correlation (FRC/FSC) machinery:

    * optional Tukey pre-filtering;
    * FFT cross-correlation binned by radial spatial frequency;
    * optional smoothing of the FRC curve;
    * resolution estimates at the 1/7, 3-sigma and half-bit thresholds.
    """
    pre_filter = Enum(['Tukey_1/8', None])
    frc_smoothing_func = Enum(['Cubic Spline', 'Sigmoid', None])
    multiprocessing = Bool(True)
    cubic_smoothing = Float(0.01)

    save_path = File()

    output_fft_images_cc = Output('FRC_fft_images_cc')
    output_frc_dict = Output('FRC_dict')
    output_frc_plot = Output('FRC_plot')
    output_frc_raw = Output('FRC_raw')

    def execute(self, namespace=None):
        """Must be implemented by derived classes."""
        raise NotImplementedError("Base class not fully implemented")

    def preprocess_images(self, image_pair):
        """
        Check that the images share a shape and apply the configured pre-filter.

        Parameters
        ----------
        image_pair : sequence of ndarray

        Returns
        -------
        sequence of ndarray
            Tukey-windowed copies for ``'Tukey_1/8'``; otherwise ``image_pair``
            unchanged.

        Raises
        ------
        AssertionError
            If the images do not all have the same shape.
        ValueError
            If ``pre_filter`` is not a recognised option.
        """
        shapes = [np.shape(im) for im in image_pair]
        assert all(s == shapes[0] for s in shapes), "Images not the same dimension."

        # apply filtering to zero near edge of images
        if self.pre_filter == 'Tukey_1/8':
            image_pair = self.filter_images_tukey(image_pair, 1. / 8)
        elif self.pre_filter is None:
            pass
        else:
            raise ValueError("Unknown pre_filter {!r}".format(self.pre_filter))

        return image_pair

    def filter_images_tukey(self, images, alpha):
        """
        Multiply each image by a separable N-D Tukey window.

        Parameters
        ----------
        images : sequence of ndarray
            Equally-shaped images. The window is built from ``images[0]``.
        alpha : float
            Tukey taper fraction.

        Returns
        -------
        list of ndarray
        """
        # scipy.signal.tukey was removed in SciPy 1.13; the windows submodule
        # has provided it since SciPy 1.1.
        from scipy.signal.windows import tukey

        windows = [tukey(images[0].shape[i], alpha=alpha) for i in range(images[0].ndim)]
        window_nd = np.prod(np.stack(np.meshgrid(*windows, indexing='ij')), axis=0)

        return [im * window_nd for im in images]

    def calculate_FRC_from_images(self, image_pair, mdh):
        """
        Compute the FRC/FSC between two equally-shaped images.

        Relies on attributes presumably set by the derived ``execute``:
        ``self._pool`` (when ``multiprocessing``), ``self._pixel_size_in_nm``
        (one spacing per image axis) and ``self._namespace``.

        Parameters
        ----------
        image_pair : sequence of two ndarray
        mdh :
            Metadata handler; currently unused.

        Returns
        -------
        res : dict
            Resolution estimates from ``calculate_threshold``.
        rawdata : dict
            Frequencies, raw and smoothed correlation, and threshold curves.
        """
        if self.multiprocessing:
            pending = [self._pool.apply_async(np.fft.fftn, (im,)) for im in image_pair]
            ft_images = [p.get() for p in pending]
        else:
            ft_images = [np.fft.fftn(im) for im in image_pair]

        # radial spatial frequency of every FFT sample
        im_fft_freqs = [np.fft.fftfreq(image_pair[0].shape[i], self._pixel_size_in_nm[i]) for i in range(image_pair[0].ndim)]
        im_R = np.linalg.norm(np.stack(np.meshgrid(*im_fft_freqs, indexing='ij')), axis=0)

        im1_fft_power = np.multiply(ft_images[0], np.conj(ft_images[0]))
        im2_fft_power = np.multiply(ft_images[1], np.conj(ft_images[1]))
        im12_fft_power = np.multiply(ft_images[0], np.conj(ft_images[1]))

        try:
            # best-effort: expose the shifted power / cross spectra for inspection
            self._namespace[self.output_fft_images_cc] = ImageStack(data=np.stack([np.atleast_3d(np.fft.fftshift(im1_fft_power)),
                   np.atleast_3d(np.fft.fftshift(im2_fft_power)),
                   np.atleast_3d(np.fft.fftshift(im12_fft_power))], 3), titleStub="ImageA_Image_FFT_CC")
        except Exception as e:
            print(e)

        r_flat = im_R.flatten()
        im1_fft_flat_res = CalculateFRCBase.BinData(r_flat, im1_fft_power.flatten(), statistic='mean', bins=201)
        im2_fft_flat_res = CalculateFRCBase.BinData(r_flat, im2_fft_power.flatten(), statistic='mean', bins=201)
        im12_fft_flat_res = CalculateFRCBase.BinData(r_flat, im12_fft_power.flatten(), statistic='mean', bins=201)

        corr = np.real(im12_fft_flat_res.statistic) / np.sqrt(np.abs(im1_fft_flat_res.statistic * im2_fft_flat_res.statistic))

        smoothed_frc = self.smooth_frc(im12_fft_flat_res.bin_edges[:-1], corr, self.cubic_smoothing)

        res, rawdata = self.calculate_threshold(im12_fft_flat_res.bin_edges[:-1], corr, smoothed_frc, im12_fft_flat_res.counts)

        return res, rawdata

    def smooth_frc(self, freq, corr, cubic_smoothing):
        """
        Return a callable ``f(x=...)`` approximating ``corr`` as a function of ``freq``.

        The form depends on ``frc_smoothing_func``:

        * ``None``: step interpolation (``kind='next'``);
        * ``'Sigmoid'``: least-squares fit of :meth:`Sigmoid`;
        * ``'Cubic Spline'``: smoothing spline with smoothing factor
          ``cubic_smoothing``.
        """
        if self.frc_smoothing_func is None:
            interp_frc = interpolate.interp1d(freq, corr, kind='next', )
            return interp_frc

        elif self.frc_smoothing_func == "Sigmoid":
            func = CalculateFRCBase.Sigmoid
            # start the fit centred on the middle frequency (integer index)
            fit_res = optimize.minimize(lambda a, x: np.sum(np.square(func(a[0], a[1], x) - corr)),
                                        [1, freq[len(freq) // 2]], args=(freq,), method='Nelder-Mead')

            return partial(func, n=fit_res.x[0], c=fit_res.x[1])
        elif self.frc_smoothing_func == "Cubic Spline":
            interp_frc = interpolate.UnivariateSpline(freq, corr, k=3, s=cubic_smoothing)

            return interp_frc

        raise ValueError("Unknown frc_smoothing_func {!r}".format(self.frc_smoothing_func))

    def calculate_threshold(self, freq, corr, corr_func, counts):
        """
        Find where the smoothed FRC crosses the 1/7, 3-sigma and half-bit thresholds.

        Also stores a ``Plot`` of the curves in
        ``self._namespace[self.output_frc_plot]``.

        Returns
        -------
        res : dict
            ``'frc 1/7'``, ``'frc 3 sigma'`` and ``'frc half bit'``: the
            reciprocal of each crossing frequency.
        rawdata : dict
            The curves, sampled at ``freq``.
        """
        res = dict()

        fsc_0143 = optimize.minimize(lambda x: np.square(corr_func(x=x) - 0.143), freq[np.argmax(corr_func(x=freq) - 0.143 < 0)], method='Nelder-Mead')
        res['frc 1/7'] = 1. / fsc_0143.x[0]

        sigma = 1.0 / np.sqrt(counts * 0.5)
        sigma_spl = interpolate.UnivariateSpline(freq, sigma, k=3, s=0)
        fsc_3sigma = optimize.minimize(lambda x: np.square(corr_func(x=x) - 3. * sigma_spl(x)), freq[np.argmax(corr_func(x=freq) - 3. * sigma_spl(freq) < 0)], method='Nelder-Mead')
        res['frc 3 sigma'] = 1. / fsc_3sigma.x[0]

        # van Heel and Schatz, 2005, Fourier shell correlation threshold criteria
        half_bit = (0.2071 + 1.9102 / np.sqrt(counts)) / (1.2071 + 0.9102 / np.sqrt(counts))
        half_bit_spl = interpolate.UnivariateSpline(freq, half_bit, k=3, s=0)
        fsc_half_bit = optimize.minimize(lambda x: np.square(corr_func(x=x) - half_bit_spl(x)), freq[np.argmax(corr_func(x=freq) - half_bit_spl(freq) < 0)], method='Nelder-Mead')
        res['frc half bit'] = 1. / fsc_half_bit.x[0]

        def plot():
            frc_text = ""
            fig, axes = subplots(1, 2, figsize=(10, 4))
            axes[0].plot(freq, corr)
            axes[0].plot(freq, corr_func(x=freq))

            axes[0].axhline(0.143, ls='--', color='red')
            axes[0].axvline(fsc_0143.x[0], ls='--', color='red', label='1/7')
            frc_text += "\n1/7:  {:.2f} nm".format(1. / fsc_0143.x[0])

            axes[0].plot(freq, 3 * sigma_spl(freq), ls='--', color='pink')
            axes[0].axvline(fsc_3sigma.x[0], ls='--', color='pink', label='3 sigma')
            frc_text += "\n3 sigma:  {:.2f} nm".format(1. / fsc_3sigma.x[0])

            axes[0].plot(freq, half_bit_spl(freq), ls='--', color='purple')
            axes[0].axvline(fsc_half_bit.x[0], ls='--', color='purple', label='1/2 bit')

            frc_text += "\n1/2 bit:  {:.2f} nm".format(1. / fsc_half_bit.x[0])

            axes[0].legend()

            # label the frequency axis in resolution units (1 / frequency)
            x_ticklocs = axes[0].get_xticks()
            axes[0].set_xticklabels(["{:.1f}".format(1. / i) for i in x_ticklocs])
            axes[0].set_ylabel("FSC/FRC")
            axes[0].set_xlabel("Resol (nm)")

            axes[1].text(0.5, 0.5, frc_text, horizontalalignment='center', verticalalignment='center', transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plot()

        self._namespace[self.output_frc_plot] = Plot(plot)

        rawdata = {'freq': freq, 'corr': corr, 'smooth': corr_func(x=freq), '1/7': np.ones_like(freq) / 7, '3 sigma': 3 * sigma_spl(freq), '1/2 bit': half_bit_spl(freq)}

        return res, rawdata

    def save_to_file(self, namespace):
        """Save raw FRC data and results to ``save_path`` (.npz) if a path is set."""
        # compare by value: ``is not ""`` tests identity and is unreliable
        if self.save_path:
            np.savez_compressed(self.save_path, raw=namespace[self.output_frc_raw], results=namespace[self.output_frc_dict])

    @staticmethod
    def BinData(indexes, data, statistic='mean', bins=10):
        """
        Binned statistic of ``data`` against ``indexes``; supports complex data.

        Parameters
        ----------
        indexes : ndarray
            Values that determine the bin of each sample (flattened).
        data : ndarray
            Values to aggregate (flattened; same size as ``indexes``).
        statistic : {'mean', 'sum'}
        bins : int
            Number of bin *edges*, spaced linearly from ``indexes.min()`` to
            ``indexes.max()``.

        Returns
        -------
        object with ``statistic`` (per bin), ``bin_edges`` and ``counts``.

        Raises
        ------
        ValueError
            If ``statistic`` is not supported.
        """
        if statistic == 'mean':
            func = np.mean
        elif statistic == 'sum':
            func = np.sum
        else:
            raise ValueError("Unsupported statistic {!r}; expected 'mean' or 'sum'".format(statistic))

        class Result(object):
            statistic = None
            bin_edges = None
            counts = None

        indexes = np.asarray(indexes).flatten()
        data = np.asarray(data).flatten()

        edges = np.linspace(indexes.min(), indexes.max(), bins)
        binned = np.zeros(len(edges) - 1, dtype=data.dtype)
        # np.int was removed in NumPy 1.24
        counts = np.zeros(len(edges) - 1, dtype=int)

        order = np.argsort(indexes)
        indexes_sorted = indexes[order]
        data_sorted = data[order]
        edges_indexes = np.searchsorted(indexes_sorted, edges)
        # Bins are half-open [lo, hi) except the last, which is closed so that
        # samples equal to the maximum index are not dropped.
        edges_indexes[-1] = len(indexes_sorted)

        for i in range(edges.shape[0] - 1):
            values = data_sorted[edges_indexes[i]:edges_indexes[i + 1]]
            binned[i] = func(values)
            counts[i] = len(values)

        res = Result()
        res.statistic = binned
        res.bin_edges = edges
        res.counts = counts
        return res

    @staticmethod
    def Sigmoid(n, c, x):
        """Logistic curve that, for ``n > 0``, falls from 1 to 0 around ``x = c``."""
        res = 1 - 1 / (1 + np.exp(n * (-x + c)))
        return res