示例#1
0
文件: gabor_1d.py 项目: fogoke/blusky
class Gabor1D(HasStrictTraits):
    """
    Construct a 1-D Morlet (Gabor) wavelet from a center frequency, a
    bandwidth and a sample rate.

    The kernel method generates the wavelet. If you want a specific length
    for the wavelet, pass it as the optional ``shape`` argument of
    ``kernel``; otherwise the length is determined by "crop".

    Optional keyword arguments:
    crop - Float
       Multiple of the envelope's standard deviation (in samples) at which
       the wavelet is cropped for output.

    taper - Bool
       If true, multiplies the wavelet by a Kaiser window (beta = 3) on
       output. This may be useful for reducing edge effects in subsequent
       convolutions.
    """

    #: Orientation of the principal axis. Kept for parity with the 2-D
    # wavelets; the 1-D kernel does not read it.
    orientation = Float

    #: Bandwidth measured as the full-width at half maximum of the
    # gaussian envelope in the frequency domain (cycles/units).
    # This is a 1-tuple, e.g. ``(30.0,)``.
    bandwidth = Tuple(Float)

    #: The center frequency of the wavelet (cycles/units).
    center_frequency = Float

    #: Sample spacing, in units/sample.
    sample_rate = Float

    #: Where (in std deviations of the spatial envelope) to crop the
    # wavelet. Should generally be > 3; smaller values are cheaper to
    # convolve but have less fidelity in the frequency domain.
    crop = Float(6.0)

    #: Optionally taper the cropped wavelet with a Kaiser window.
    taper = Bool(False)

    #: Default kernel length as a 1-tuple. Trades fidelity against
    # convolution cost (smaller is cheaper). Depends on crop as well as
    # on the envelope width.
    shape = Property(Tuple(Int), depends_on=["_sigma", "crop"])

    #: (Optional) labels the scale of the wavelet; useful in a filter bank.
    scale = Int(-1)

    #: Standard deviation of the spatial envelope, measured in samples.
    _sigma = Property(
        Tuple(Float),
        depends_on=["bandwidth", "center_frequency", "sample_rate"],
    )

    def __init__(self, center_frequency, bandwidth, sample_rate, **traits):
        """
        Parameters
        ----------
        center_frequency - Float
           Center frequency of the wavelet (cycles/units).

        bandwidth - Tuple(Float)
           1-tuple holding the FWHM bandwidth of the wavelet (cycles/units).

        sample_rate - Float
           Sample spacing (units/sample).

        Optional keyword arguments:

        crop - Float
           Multiple of the envelope's standard deviation at which the
           wavelet is cropped for output.

        taper - Bool
           If true, applies a Kaiser window to the wavelet on output. This
           may be useful for reducing edge effects in subsequent
           convolutions.

        usage:

        wav = Gabor1D(sample_rate=0.004,
                      center_frequency=60.,
                      bandwidth=(30.,))
        """

        self.center_frequency = center_frequency
        self.bandwidth = bandwidth
        self.sample_rate = sample_rate

        super().__init__(**traits)

    def _get__sigma(self):
        """
        sigma parameterizes the gaussian envelope of the wavelet, in samples.

        A gaussian envelope with spatial standard deviation sigma has a
        gaussian Fourier spectrum with standard deviation 1/sigma
        (radians/sample), whose FWHM is ~2.355 / sigma. Inverting that
        relation for the requested bandwidth gives sigma.
        """
        def to_angular(f):
            # cycles/unit -> radians/sample
            return 2 * np.pi * f * self.sample_rate

        return tuple(2.355 / to_angular(f) for f in self.bandwidth)

    def _get_shape(self):
        """
        Default kernel length: ``crop`` envelope standard deviations,
        rounded up to an odd number of samples.
        """
        n = int(self.crop * max(self._sigma))
        # An odd length keeps the wavelet centred on a single sample.
        n += 1 - (n % 2)
        return (n, )

    def _taper(self, length=None):
        """ Compute a Kaiser window (beta = 3) used to taper the wavelet.

        Parameters
        ----------
        length - Int (optional)
           Window length; defaults to ``self.shape[0]``.
        """
        if length is None:
            length = self.shape[0]
        return np.kaiser(length, 3)

    def kernel(self, shape=None):
        """
        Output the wavelet in a complex valued array.

        Derivative of the work: morlet_2d_pyramid.m, we applied the
        same idea to 1-d for consistency.

        from https://github.com/scatnet/scatnet

        See license in NOTICE.txt in this directory.

        Parameters
        ----------
        shape - Int or sequence of Int (optional)
           Required length for the output wavelet (for a sequence, its
           first element is used). If not provided, the length from
           ``self.shape`` (determined by crop) is used.

        Return
        ------
        wavelet - Array
           A 1d complex array containing the wavelet, normalized so that
           the sum of its absolute values is 1. If ``taper`` is set, the
           Kaiser taper is applied before normalization.
        """

        if shape is None:
            n_samples = self.shape[0]
        elif isinstance(shape, (int, np.integer)):
            n_samples = int(shape)
        else:
            n_samples = int(shape[0])

        # sample positions, centred on the middle sample
        x = np.arange(n_samples) - n_samples // 2

        # center frequency in radians per sample
        xi = 2 * np.pi * self.center_frequency * self.sample_rate

        sigma = self._sigma[0]
        gaussian_envelope = np.exp(-x * x / (2 * sigma ** 2))
        gabc = gaussian_envelope * np.exp(1j * x * xi)

        if self.taper:
            gabc = gabc * self._taper(n_samples)

        return gabc / np.abs(gabc).sum()
示例#2
0
class WizardPage(MWizardPage, HasTraits):
    """ The toolkit specific implementation of a WizardPage.

    See the IWizardPage interface for the API documentation.

    """

    # 'IWizardPage' interface ---------------------------------------------#

    id = Str()

    next_id = Str()

    last_page = Bool(False)

    complete = Bool(False)

    heading = Str()

    subheading = Str()

    size = Tuple()

    # ------------------------------------------------------------------------
    # 'IWizardPage' interface.
    # ------------------------------------------------------------------------

    def create_page(self, parent):
        """ Creates the wizard page.

        ``_create_page_content`` may return a QWizardPage (used as-is), a
        QLayout (installed on a new page) or any other QWidget (wrapped in
        a vertical layout on a new page).

        Raises
        ------
        TypeError
            If the page content is none of the accepted Qt types.
        """

        content = self._create_page_content(parent)

        # We allow some flexibility with the sort of control we are given.
        if not isinstance(content, QtGui.QWizardPage):
            page = _WizardPage(self)

            if isinstance(content, QtGui.QLayout):
                page.setLayout(content)
            elif isinstance(content, QtGui.QWidget):
                layout = QtGui.QVBoxLayout()
                layout.addWidget(content)
                page.setLayout(layout)
            else:
                # Explicit check instead of an assert, which is stripped
                # when Python runs with -O.
                raise TypeError(
                    "_create_page_content() must return a QWizardPage, "
                    "QLayout or QWidget, not %s" % type(content).__name__)

            content = page

        # Honour any requested page size (non-positive entries are ignored).
        if self.size:
            width, height = self.size

            if width > 0:
                content.setMinimumWidth(width)

            if height > 0:
                content.setMinimumHeight(height)

        content.setTitle(self.heading)
        content.setSubTitle(self.subheading)

        return content

    # ------------------------------------------------------------------------
    # Protected 'IWizardPage' interface.
    # ------------------------------------------------------------------------

    def _create_page_content(self, parent):
        """ Creates the actual page content.

        This default is a placeholder (a yellow-filled QWidget); subclasses
        are expected to override it.
        """

        control = QtGui.QWidget(parent)

        palette = control.palette()
        palette.setColor(QtGui.QPalette.Window, QtGui.QColor("yellow"))
        control.setPalette(palette)
        control.setAutoFillBackground(True)

        return control
class Canvas(Container):
    """
    An infinite canvas with components on it.  It can optionally be given
    a "view region" which will be used as the notional bounds of the
    canvas in all operations that require bounds.

    A Canvas can be nested inside another container, but usually a
    viewport is more appropriate.

    Note: A Canvas has infinite bounds, but its .bounds attribute is
    overloaded to be something more meaningful, namely, the bounding
    box of its child components and the optional view area of the
    viewport that is looking at it.  (TODO: add support for multiple
    viewports.)
    """

    # This optional tuple of (x,y,x2,y2) allows viewports to inform the canvas of
    # the "region of interest" that it should use when computing its notional
    # bounds for clipping and event handling purposes.  If this trait is None,
    # then the canvas really does behave as if it has no bounds.
    view_bounds = Trait(None, None, Tuple)

    # The (x,y) position of the lower-left corner of the rectangle corresponding
    # to the dimensions in self.bounds.  Unlike self.position, this position is
    # in the canvas's space, and not in the coordinate space of the parent.
    bounds_offset = List

    # If True (and view_bounds is set), draw the x=0 and y=0 axes in the
    # underlay.
    draw_axes = Bool(False)

    #------------------------------------------------------------------------
    # Inherited traits
    #------------------------------------------------------------------------

    # Use the auto-size/fit_components mechanism to ensure that the bounding
    # box around our inner components gets updated properly.
    auto_size = True
    fit_components = "hv"

    # The following traits are ignored, but we set them to sensible values.
    fit_window = False
    resizable = "hv"

    #------------------------------------------------------------------------
    # Protected traits
    #------------------------------------------------------------------------

    # The (x, y, x2, y2) coordinates of the bounding box of the components
    # in our inner coordinate space
    _bounding_box = Tuple((0, 0, 100, 100))

    def compact(self):
        """
        Recompute the bounding box of the child components and refresh
        ``bounds``/``bounds_offset`` so that they also take the view
        bounds (if present) into account.

        This does not call the superclass implementation.
        """
        self._bounding_box = self._calc_bounding_box()
        self._view_bounds_changed()

    def is_in(self, x, y):
        """ A canvas is unbounded, so every point is inside it. """
        return True

    def remove(self, *components):
        """ Removes components from this container.

        Raises
        ------
        RuntimeError
            If a component is not a child of this canvas.  Components that
            appear earlier in ``components`` have already been removed.
        """
        needs_compact = False
        for component in components:
            if component not in self._components:
                raise RuntimeError(
                    "Unable to remove component from container.")
            component.container = None
            self._components.remove(component)

            # If the removed component touched the edge of the bounding
            # box, the box may shrink, so compact once we are done.
            # NOTE(review): outer_x2/outer_y2 are compared against the box
            # width/height (x2 - x, y2 - y), which equals the box edge only
            # when the box starts at 0 -- confirm the intended coordinates.
            x, y, x2, y2 = self._bounding_box
            if (component.outer_x2 == x2 - x
                    or component.outer_y2 == y2 - y
                    or component.x == 0 or component.y == 0):
                needs_compact = True

        if needs_compact:
            self.compact()
        self.invalidate_draw()

    def draw(self, gc, view_bounds=None, mode="normal"):
        """ Draw the canvas, adopting ``view_bounds`` as the canvas's view
        region if none has been set yet. """
        if self.view_bounds is None:
            self.view_bounds = view_bounds
        super(Canvas, self).draw(gc, view_bounds, mode)

    #------------------------------------------------------------------------
    # Protected methods
    #------------------------------------------------------------------------

    def _should_compact(self):
        """ Return True if auto-sizing is on and some child component lies
        outside the current bounds; False otherwise. """
        if not self.auto_size:
            return False

        if self.view_bounds is not None:
            llx, lly = self.view_bounds[:2]
        else:
            llx = lly = 0

        # Always return a real bool (previously fell through to None).
        return any(
            component.outer_x2 >= self.width
            or component.outer_y2 >= self.height
            or component.outer_x < llx
            or component.outer_y < lly
            for component in self.components)

    def _draw_background(self, gc, view_bounds=None, mode="default"):
        """ Fill the view region (or the component bounding box when no
        view region is set) with the background colour, then draw the
        border if it is not drawn as an overlay. """
        if self.bgcolor not in ("clear", "transparent", "none"):
            if self.view_bounds is not None:
                x, y, x2, y2 = self.view_bounds
            else:
                x, y, x2, y2 = self._bounding_box
            r = (x, y, x2 - x + 1, y2 - y + 1)

            with gc:
                gc.set_antialias(False)
                gc.set_fill_color(self.bgcolor_)
                gc.draw_rect(r, FILL)

        # Call the enable _draw_border routine
        if not self.overlay_border and self.border_visible:
            # Tell _draw_border to ignore the self.overlay_border
            self._draw_border(gc, view_bounds, mode, force_draw=True)

    def _draw_underlay(self, gc, view_bounds=None, mode="default"):
        """ Optionally draw the coordinate axes, then the inherited
        underlay. """
        # The axes are clipped to the view region, so they can only be
        # drawn once one has been set (unpacking None would raise).
        if self.draw_axes and self.view_bounds is not None:
            x, y, x2, y2 = self.view_bounds
            if (x <= 0 <= x2) or (y <= 0 <= y2):
                with gc:
                    gc.set_stroke_color((0, 0, 0, 1))
                    gc.set_line_width(1.0)
                    gc.move_to(0, y)
                    gc.line_to(0, y2)
                    gc.move_to(x, 0)
                    gc.line_to(x2, 0)
                    gc.stroke_path()
        # NOTE(review): this starts the MRO lookup *after* Container, so
        # Container's own _draw_underlay is skipped -- confirm intended.
        super(Container, self)._draw_underlay(gc, view_bounds, mode)

    def _transform_view_bounds(self, view_bounds):
        """ Shift ``view_bounds`` into the canvas's coordinate space. """
        # Overload the parent class's implementation to skip visibility test
        if view_bounds:
            v = view_bounds
            new_bounds = (v[0] - self.x, v[1] - self.y, v[2], v[3])
        else:
            new_bounds = None
        return new_bounds

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _bounds_offset_default(self):
        return [0, 0]

    def _view_bounds_changed(self):
        """ Set bounds/bounds_offset to the union of the component bounding
        box and the view region (if any). """
        llx, lly, urx, ury = self._bounding_box
        if self.view_bounds is not None:
            x, y, x2, y2 = self.view_bounds
            llx = min(llx, x)
            lly = min(lly, y)
            urx = max(urx, x2)
            ury = max(ury, y2)
        self.bounds_offset = [llx, lly]
        self.bounds = [urx - llx + 1, ury - lly + 1]

    # Override Container.bounds_changed so that _layout_needed is not
    # set.  Containers need to invalidate layout because they act as
    # sizers, but the Canvas is unbounded and thus does not need to
    # invalidate layout.
    def _bounds_changed(self, old, new):
        Component._bounds_changed(self, old, new)
        self.invalidate_draw()

    def _bounds_items_changed(self, event):
        Component._bounds_items_changed(self, event)
        self.invalidate_draw()
class Window(MWindow, Widget):
    """ The toolkit specific implementation of a Window.  See the IWindow
    interface for the API documentation.
    """

    #### 'IWindow' interface ##################################################

    position = Property(Tuple)

    size = Property(Tuple)

    title = Str  #Unicode

    #### Events #####

    activated = Event

    closed = Event

    closing = Event

    deactivated = Event

    key_pressed = Event(KeyPressedEvent)

    opened = Event

    opening = Event

    #### Private interface ####################################################

    # Shadow trait for position.
    _position = Tuple((-1, -1))

    # Shadow trait for size.
    _size = Tuple((-1, -1))

    ###########################################################################
    # 'IWindow' interface.
    ###########################################################################

    def activate(self):
        """ Restore the frame if it is iconized and raise it. """
        self.control.Iconize(False)
        self.control.Raise()

    def show(self, visible):
        """ Show or hide the frame. """
        self.control.Show(visible)

    ###########################################################################
    # Protected 'IWindow' interface.
    ###########################################################################

    def _add_event_listeners(self):
        """ Bind the wx frame events to this window's handlers. """
        self.control.Bind(wx.EVT_ACTIVATE, self._wx_on_activate)
        self.control.Bind(wx.EVT_CLOSE, self._wx_on_close)
        self.control.Bind(wx.EVT_SIZE, self._wx_on_control_size)
        self.control.Bind(wx.EVT_MOVE, self._wx_on_control_move)
        self.control.Bind(wx.EVT_CHAR, self._wx_on_char)

    ###########################################################################
    # Protected 'IWidget' interface.
    ###########################################################################

    def _create_control(self, parent):
        """ Create a basic wx.Frame using the current title, size and
        position. """

        style = wx.DEFAULT_FRAME_STYLE \
                | wx.FRAME_NO_WINDOW_MENU \
                | wx.CLIP_CHILDREN

        control = wx.Frame(parent,
                           -1,
                           self.title,
                           style=style,
                           size=self.size,
                           pos=self.position)

        control.SetBackgroundColour(SystemMetrics().dialog_background_color)

        return control

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_position(self):
        """ Property getter for position. """

        return self._position

    def _set_position(self, position):
        """ Property setter for position.

        Moves the control (if it exists), updates the shadow trait and
        fires a 'position' change notification.
        """

        if self.control is not None:
            self.control.SetPosition(position)

        old = self._position
        self._position = position

        self.trait_property_changed('position', old, position)

    def _get_size(self):
        """ Property getter for size. """

        return self._size

    def _set_size(self, size):
        """ Property setter for size.

        Resizes the control (if it exists), updates the shadow trait and
        fires a 'size' change notification.
        """

        if self.control is not None:
            self.control.SetSize(size)

        old = self._size
        self._size = size

        self.trait_property_changed('size', old, size)

    def _title_changed(self, title):
        """ Static trait change handler. """

        if self.control is not None:
            self.control.SetTitle(title)

    #### wx event handlers ####################################################

    def _wx_on_activate(self, event):
        """ Called when the frame is being activated or deactivated. """

        if event.GetActive():
            self.activated = self
        else:
            self.deactivated = self

        event.Skip()

    def _wx_on_close(self, event):
        """ Called when the frame is being closed. """

        self.close()

    def _wx_on_control_move(self, event):
        """ Called when the window is moved. """

        # Get the real position and set the shadow trait without performing
        # notification.

        # WXBUG - From the API documentation you would think that you could
        # call event.GetPosition directly, but that would be wrong.  The pixel
        # reported by that call is the pixel just below the window menu and
        # just right of the Windows-drawn border.

        try:
            self._position = event.GetEventObject().GetPosition().Get()
        except Exception:
            # Best effort: keep the previous shadow value if the position
            # cannot be read.  (Previously a bare except, which also
            # swallowed KeyboardInterrupt/SystemExit.)
            pass
        event.Skip()

    def _wx_on_control_size(self, event):
        """ Called when the window is resized. """

        # Get the new size and set the shadow trait without performing
        # notification.

        wxsize = event.GetSize()

        self._size = (wxsize.GetWidth(), wxsize.GetHeight())

        event.Skip()

    def _wx_on_char(self, event):
        """ Called when a key is pressed while the window has focus. """

        self.key_pressed = KeyPressedEvent(
            alt_down=event.AltDown() == 1,
            control_down=event.ControlDown() == 1,
            shift_down=event.ShiftDown() == 1,
            key_code=event.GetKeyCode(),
            event=event)

        event.Skip()
示例#5
0
class PairList(ValueNode):
    # A list whose entries are either a bare Name or a (Name, Node) pair.
    value = List(
        Either(
            Name,
            Tuple(Name, Node),
        )
    )
示例#6
0
class BleedthroughLinearOp(HasStrictTraits):
    """
    Apply matrix-based bleedthrough correction to a set of fluorescence channels.

    This is a traditional matrix-based compensation for bleedthrough.  For each
    pair of channels, the user specifies the proportion of the first channel
    that bleeds through into the second; then, the module performs a matrix
    multiplication to compensate the raw data.

    The module can also estimate the bleedthrough matrix using one
    single-color control per channel.

    This works best on data that has had autofluorescence removed first;
    if that is the case, then the autofluorescence will be subtracted from
    the single-color controls too.

    To use, set up the :attr:`controls` dict with the single color controls;
    call :meth:`estimate` to parameterize the operation; check that the bleedthrough
    plots look good by calling :meth:`~.BleedthroughLinearDiagnostic.plot` on
    the :class:`BleedthroughLinearDiagnostic` instance returned by
    :meth:`default_view`; and then :meth:`apply` on an :class:`.Experiment`.

    Attributes
    ----------
    controls : Dict(Str, File)
        The channel names to correct, and corresponding single-color control
        FCS files to estimate the correction splines with.  Must be set to
        use :meth:`estimate`.

    spillover : Dict(Tuple(Str, Str), Float)
        The spillover "matrix" to use to correct the data.  The keys are pairs
        of channels, and the values are proportions of spectral overlap.  If
        ``("channel1", "channel2")`` is present as a key,
        ``("channel2", "channel1")`` must also be present.  The module does not
        assume that the matrix is symmetric.

    control_conditions : Dict(Str, Dict(Str, Any))
        Occasionally, you'll need to specify the experimental conditions that
        the bleedthrough tubes were collected under (to apply the operations in the
        history.)  Specify them here.  The key is the channel name; the value
        is a dictionary of the conditions (same as you would specify for a
        :class:`~.Tube` )

    Examples
    --------

    Create a small experiment:

    .. plot::
        :context: close-figs

        >>> import cytoflow as flow
        >>> import_op = flow.ImportOp()
        >>> import_op.tubes = [flow.Tube(file = "tasbe/rby.fcs")]
        >>> ex = import_op.apply()

    Correct for autofluorescence

    .. plot::
        :context: close-figs

        >>> af_op = flow.AutofluorescenceOp()
        >>> af_op.channels = ["Pacific Blue-A", "FITC-A", "PE-Tx-Red-YG-A"]
        >>> af_op.blank_file = "tasbe/blank.fcs"

        >>> af_op.estimate(ex)
        >>> af_op.default_view().plot(ex)

        >>> ex2 = af_op.apply(ex)

    Create and parameterize the operation

    .. plot::
        :context: close-figs

        >>> bl_op = flow.BleedthroughLinearOp()
        >>> bl_op.controls = {'Pacific Blue-A' : 'tasbe/ebfp.fcs',
        ...                   'FITC-A' : 'tasbe/eyfp.fcs',
        ...                   'PE-Tx-Red-YG-A' : 'tasbe/mkate.fcs'}

    Estimate the model parameters

    .. plot::
        :context: close-figs

        >>> bl_op.estimate(ex2)

    Plot the diagnostic plot

    .. note::
       The diagnostic plots look really bad in the online documentation.
       They're better in a real-world example, I promise!

    .. plot::
        :context: close-figs

        >>> bl_op.default_view().plot(ex2)

    Apply the operation to the experiment

    .. plot::
        :context: close-figs

        >>> ex2 = bl_op.apply(ex2)

    """

    # traits
    id = Constant('edu.mit.synbio.cytoflow.operations.bleedthrough_linear')
    friendly_id = Constant("Linear Bleedthrough Correction")

    name = Constant("Bleedthrough")

    controls = Dict(Str, File)
    spillover = Dict(Tuple(Str, Str), Float)
    control_conditions = Dict(Str, Dict(Str, Any), {})

    # up to 1000 events per control, kept for the diagnostic plots
    _sample = Dict(Str, Any, transient=True)

    def estimate(self, experiment, subset=None):
        """
        Estimate the bleedthrough from single-channel controls in :attr:`controls`.

        Parameters
        ----------
        experiment : Experiment
            The experiment whose channels, metadata and history are used to
            load and pre-process the controls.

        subset : Str (optional)
            A query string selecting the events of each control to use.

        Raises
        ------
        CytoflowOpError
            If the experiment or controls are missing or invalid, or the
            subset string is invalid or selects no events.
        """
        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        channels = list(self.controls.keys())

        if len(channels) < 2:
            raise util.CytoflowOpError(
                'channels',
                "Need at least two channels to correct bleedthrough.")

        # make sure the control files exist
        for channel in channels:
            if not os.path.isfile(self.controls[channel]):
                raise util.CytoflowOpError(
                    'channels', "Can't find file {0} for channel {1}.".format(
                        self.controls[channel], channel))

        self.spillover.clear()
        self._sample.clear()

        # bleedthrough model: to_channel = k * from_channel
        def proportional(x, k):
            return x * k

        for channel in channels:
            tube_data = self._control_experiment(
                channel, experiment, subset).data

            # polyfit requires sorted data
            tube_data.sort_values(channel, inplace=True)

            # save a little of the data to plot later (controls with fewer
            # than 1000 events are kept whole instead of raising)
            self._sample[channel] = tube_data.sample(
                n=min(1000, len(tube_data)))

            from_channel = channel

            # sometimes some of the data is off the edge of the
            # plot, and this screws up a linear regression

            from_min = np.min(tube_data[from_channel]) * 1.025
            from_max = np.max(tube_data[from_channel]) * 0.975
            tube_data[from_channel] = \
                tube_data[from_channel].clip(from_min, from_max)

            # the clipping below does not depend on the index, so one reset
            # (after the sort) is enough
            tube_data.reset_index(drop=True, inplace=True)

            for to_channel in channels:
                if from_channel == to_channel:
                    continue

                to_min = np.min(tube_data[to_channel]) * 1.025
                to_max = np.max(tube_data[to_channel]) * 0.975
                tube_data[to_channel] = \
                    tube_data[to_channel].clip(to_min, to_max)

                popt, _ = scipy.optimize.curve_fit(proportional,
                                                   tube_data[from_channel],
                                                   tube_data[to_channel], 0)

                self.spillover[(from_channel, to_channel)] = popt[0]

    def _control_experiment(self, channel, experiment, subset):
        """Load *channel*'s control tube, replay *experiment*'s history on it
        and apply *subset*; return the resulting Experiment."""
        check_tube(self.controls[channel], experiment)
        tube_conditions = self.control_conditions.get(channel, {})
        exp_conditions = {
            k: experiment.data[k].dtype.name
            for k in tube_conditions.keys()
        }

        tube_exp = ImportOp(
            tubes=[
                Tube(file=self.controls[channel],
                     conditions=tube_conditions)
            ],
            conditions=exp_conditions,
            channels={
                experiment.metadata[c]["fcs_name"]: c
                for c in experiment.channels
            },
            name_metadata=experiment.metadata['name_metadata']).apply()

        # apply previous operations
        for op in experiment.history:
            if hasattr(op, 'by'):
                for by in op.by:
                    if 'experiment' in experiment.metadata[by]:
                        raise util.CytoflowOpError(
                            'experiment',
                            "Prior to applying this operation, "
                            "you must not apply any operation with 'by' "
                            "set to an experimental condition.")
            tube_exp = op.apply(tube_exp)

        # subset it (report the *argument*; this class has no subset trait)
        if subset:
            try:
                tube_exp = tube_exp.query(subset)
            except Exception as exc:
                raise util.CytoflowOpError(
                    'subset', "Subset string '{0}' isn't valid".format(
                        subset)) from exc

            if len(tube_exp.data) == 0:
                raise util.CytoflowOpError(
                    'subset',
                    "Subset string '{0}' returned no events".format(
                        subset))

        return tube_exp

    def apply(self, experiment):
        """Applies the bleedthrough correction to an experiment.

        Parameters
        ----------
        experiment : Experiment
            The experiment to which this operation is applied

        Returns
        -------
        Experiment
            A new :class:`Experiment` with the bleedthrough subtracted out.
            The corrected channels have the following metadata added:

            - **linear_bleedthrough** : Dict(Str : Float)
              The values for spillover from other channels into this channel.

            - **bleedthrough_channels** : List(Str)
              The channels that were used to correct this one.

            - **bleedthrough_fn** : Callable (Tuple(Float) --> Float)
              The function that will correct one event in this channel.  Pass it
              the values specified in `bleedthrough_channels` and it will return
              the corrected value for this channel.

        Raises
        ------
        CytoflowOpError
            If the experiment or spillover matrix is missing or inconsistent
            with the experiment's channels.
        """
        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        if not self.spillover:
            raise util.CytoflowOpError(
                'spillover', "Spillover matrix isn't set. "
                "Did you forget to run estimate()?")

        for (from_channel, to_channel) in self.spillover:
            if from_channel not in experiment.data:
                raise util.CytoflowOpError(
                    'spillover', "Can't find channel {0} in experiment".format(
                        from_channel))
            if to_channel not in experiment.data:
                raise util.CytoflowOpError(
                    'spillover',
                    "Can't find channel {0} in experiment".format(to_channel))

            if (to_channel, from_channel) not in self.spillover:
                raise util.CytoflowOpError(
                    'spillover', "Must have both (from, to) and "
                    "(to, from) keys in self.spillover")

        new_experiment = experiment.clone()

        # A fixed (sorted) channel order; the correction does not depend on
        # it, but this makes 'bleedthrough_channels' reproducible.
        channels = sorted({x for (x, _) in self.spillover})

        # build the spillover matrix: a[i][j] is the proportion of channel i
        # that bleeds into channel j
        a = [[self.spillover[(y, x)] if x != y else 1.0 for x in channels]
             for y in channels]

        # invert it.  use the pseudoinverse in case a is singular
        a_inv = np.linalg.pinv(a)

        # compute the corrected channels (each row is one event)
        new_channels = np.dot(experiment.data[channels], a_inv)

        # and assign to the new experiment, aligned with the original rows
        # even if the data's index is not 0..n-1
        for i, c in enumerate(channels):
            new_experiment[c] = pd.Series(new_channels[:, i],
                                          index=experiment.data.index)

        for channel in channels:
            # add the spillover values to the channel's metadata
            new_experiment.metadata[channel]['linear_bleedthrough'] = \
                {x : self.spillover[(x, channel)]
                     for x in channels if x != channel}
            new_experiment.metadata[channel]['bleedthrough_channels'] = list(
                channels)
            # NOTE(review): this returns the corrected values for *all*
            # bleedthrough_channels, not just this channel as documented
            # above -- confirm what consumers expect.
            new_experiment.metadata[channel][
                'bleedthrough_fn'] = lambda x, a_inv=a_inv: np.dot(x, a_inv)

        new_experiment.history.append(
            self.clone_traits(transient=lambda _: True))
        return new_experiment

    def default_view(self, **kwargs):
        """
        Returns a diagnostic plot to make sure spillover estimation is working.

        Returns
        -------
        IView
            An IView, call :meth:`~BleedthroughLinearDiagnostic.plot` to see the diagnostic plots

        Raises
        ------
        CytoflowOpError
            If the controls and the estimated spillover cover different
            channels.
        """

        spillover_channels = {x for (x, _) in self.spillover}

        if set(self.controls) != spillover_channels:
            raise util.CytoflowOpError(
                'controls',
                "Must have both the controls and bleedthrough to plot")

        v = BleedthroughLinearDiagnostic(op=self)
        v.trait_set(**kwargs)
        return v
示例#7
0
class ToolPalette(Widget):
    """ Toolkit-neutral tool palette.

    Every method in this class is a placeholder: it accepts its documented
    arguments and returns a fixed value (``1``, ``0`` or ``None``) without
    touching any toolkit control.  ``_create_control`` returns ``None``.
    NOTE(review): presumably toolkit-specific subclasses override these.
    """

    tools = List

    id_tool_map = Dict

    tool_id_to_button_map = Dict

    button_size = Tuple((25, 25), Int, Int)

    is_realized = Bool(False)

    tool_listeners = Dict

    # Maps a button id to its tool id.
    button_tool_map = Dict

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, parent, **traits):
        """ Creates a new tool palette.

        *traits* are forwarded to the base-class constructor; afterwards
        ``control`` is set to whatever ``_create_control(parent)`` returns.
        """
        super(ToolPalette, self).__init__(**traits)
        self.control = self._create_control(parent)

    ###########################################################################
    # ToolPalette interface.
    ###########################################################################

    def add_tool(self, label, bmp, kind, tooltip, longtip):
        """ Add a tool with the given properties to the palette.

        Returns the id for later references to the tool (always 1 here).
        """
        tool_id = 1
        return tool_id

    def toggle_tool(self, id, checked):
        """ Set the tool identified by 'id' to the 'checked' state.

        Toggle and radio buttons become checked when 'checked' is True and
        unchecked otherwise; standard buttons ignore the call.  This
        implementation does nothing.
        """

    def enable_tool(self, id, enabled):
        """ Enable or disable the tool identified by 'id' (no-op here). """

    def on_tool_event(self, id, callback):
        """ Register 'callback' for events on tool 'id' (no-op here). """

    def realize(self):
        """ Make the control displayable (no-op here). """

    def get_tool_state(self, id):
        """ Return the toggle state of tool 'id' (always 0 here). """
        return 0

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_control(self, parent):
        """ Create the toolkit control; this base version creates none. """
        return None
示例#8
0
class Container(Component):
    """
    A Container is a logical container that holds other Components within it and
    provides an origin for Components to position themselves.  Containers can
    be "nested" (although "overlayed" is probably a better term).

    If auto_size is True, the container will automatically update its bounds to
    enclose all of the components handed to it, so that a container's bounds
    serve as a bounding box (although not necessarily a minimal bounding box) of
    its contained components.
    """

    # The list of components within this frame
    components = Property  # List(Component)

    # Whether or not the container should automatically maximize itself to
    # fit inside the Window, if this is a top-level container.
    #
    # NOTE: the way that a Container determines that it's a top-level window is
    # that someone has explicitly set its .window attribute. If you need to do
    # this for some other reason, you may want to turn fit_window off.
    fit_window = Bool(True)

    # If true, the container get events before its children.  Otherwise, it
    # gets them afterwards.
    intercept_events = Bool(True)

    # Dimensions in which this container can resize to fit its components.
    # This trait only applies to dimensions that are also resizable; if the
    # container is not resizable in a certain dimension, then fit_components
    # has no effect.
    #
    # Also, note that the container does not *automatically* resize itself
    # based on the value of this trait.  Rather, this trait determines
    # what value is reported in get_preferred_size(); it is up to the parent
    # of this container to make sure that it is allocated the size that it
    # needs by setting its bounds appropriately.
    #
    # TODO: Merge resizable and this into a single trait?  Or have a separate
    # "fit" flag for each dimension in the **resizable** trait?
    # TODO: This trait is used in layout methods of various Container
    # subclasses in Chaco.  We need to move those containers into
    # Enable.
    fit_components = Enum("", "h", "v", "hv")

    # Whether or not the container should auto-size itself to fit all of its
    # components.
    # Note: This trait is still used, but will be eventually removed in favor
    # of **fit_components**.
    auto_size = Bool(False)

    # The default size of this container if it is empty.
    default_size = Tuple(0, 0)

    # The layers that the container will draw first, so that they appear
    # under the component layers of the same name.
    container_under_layers = Tuple("background", "image", "underlay",
                                   "mainlayer")

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Shadow trait for self.components
    _components = List  # List(Component)

    # Set of components that last handled a mouse event.  We keep track of
    # this so that we can generate mouse_enter and mouse_leave events of
    # our own.
    _prev_event_handlers = Instance(set, ())

    # This container can render itself in a different mode than what it asks of
    # its contained components.  This attribute stores the rendering mode that
    # this container requests of its children when it does a _draw(). If the
    # attribute is set to "default", then whatever mode is handed in to _draw()
    # is used.
    _children_draw_mode = Enum("default", "normal", "overlay", "interactive")

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, *components, **traits):
        """ Create the container and add the positional *components* to it.

        Keyword arguments are passed to ``Component.__init__`` as trait
        values.  Passing ``intercept_events`` emits a DeprecationWarning.
        """
        Component.__init__(self, **traits)
        for component in components:
            self.add(component)
        # An explicitly sized container does not auto-size unless the caller
        # also asked for it explicitly.
        if "bounds" in traits and "auto_size" not in traits:
            self.auto_size = False

        if 'intercept_events' in traits:
            # DeprecationWarning is a builtin; the former
            # ``warnings.DeprecationWarning`` does not exist and raised
            # AttributeError instead of warning.
            warnings.warn("'intercept_events' is a deprecated trait",
                          DeprecationWarning, stacklevel=2)
        return

    def add(self, *components):
        """ Adds components to this container.

        Each component is first removed from any container it currently
        belongs to.  The container compacts itself afterwards if
        ``_should_compact()`` says so.
        """
        for component in components:
            if component.container is not None:
                component.container.remove(component)
            component.container = self
        self._components.extend(components)

        # Expand our bounds if necessary
        if self._should_compact():
            self.compact()

        self.invalidate_draw()

    def remove(self, *components):
        """ Removes components from this container.

        Raises RuntimeError for a component that is not in this container;
        components earlier in the argument list have been removed by then.
        """
        for component in components:
            if component not in self._components:
                raise RuntimeError(
                    "Unable to remove component from container.")
            component.container = None
            self._components.remove(component)

            # If the removed component touched one of our edges, our
            # bounding box may shrink.  Outer (padded) edges are used on all
            # four sides, consistent with _calc_bounding_box() and
            # _should_compact().
            if self.auto_size:
                if (component.outer_x2 == self.width) or \
                        (component.outer_y2 == self.height) or \
                        (component.outer_x == 0) or (component.outer_y == 0):
                    self.compact()

        self.invalidate_draw()

    def insert(self, index, component):
        "Inserts a component into a specific position in the components list"
        if component.container is not None:
            component.container.remove(component)
        component.container = self
        self._components.insert(index, component)

        # Expand our bounds if necessary, exactly as add() does.
        if self._should_compact():
            self.compact()

        self.invalidate_draw()

    def components_at(self, x, y):
        """
        Returns a list of the components underneath the given point (given in
        the parent coordinate frame of this container).

        The list is ordered topmost first (the reverse of ``_components``).
        """
        result = []
        if self.is_in(x, y):
            # Translate into our local coordinate frame.
            xprime = x - self.position[0]
            yprime = y - self.position[1]
            for component in self._components[::-1]:
                if component.is_in(xprime, yprime):
                    result.append(component)
        return result

    def raise_component(self, component):
        """ Raises the indicated component to the top of the Z-order """
        # The last component is drawn last, i.e. on top.
        c = self._components
        ndx = c.index(component)
        if len(c) > 1 and ndx != len(c) - 1:
            self._components = c[:ndx] + c[ndx + 1:] + [component]
        return

    def lower_component(self, component):
        """ Puts the indicated component to the very bottom of the Z-order """
        # The first component is drawn first, i.e. at the bottom; this is the
        # mirror image of raise_component().
        c = self._components
        ndx = c.index(component)
        if len(c) > 1 and ndx != 0:
            self._components = [component] + c[:ndx] + c[ndx + 1:]
        return

    def cleanup(self, window):
        """When a window viewing or containing a component is destroyed,
        cleanup is called on the component to give it the opportunity to
        delete any transient state it may have (such as backbuffers)."""
        for component in self._components:
            component.cleanup(window)
        return

    def compact(self):
        """
        Causes this container to update its bounds to be a compact bounding
        box of its components.  This may cause the container to recalculate
        and adjust its position relative to its parent container (and adjust
        the positions of all of its contained components accordingly).
        """
        # Loop over our components and determine the bounding box of all of
        # the components.
        ll_x, ll_y, ur_x, ur_y = self._calc_bounding_box()
        if self._components:
            # Update the positions of all of our components, but do it
            # quietly.  ``trait_set`` replaces the deprecated HasTraits.set.
            for component in self._components:
                component.trait_set(
                    position=[component.x - ll_x, component.y - ll_y],
                    trait_change_notify=False)

            # Change our position (in our parent's coordinate frame) and
            # update our bounds
            self.position = [self.x + ll_x, self.y + ll_y]

        self.bounds = [ur_x - ll_x, ur_y - ll_y]
        return

    #------------------------------------------------------------------------
    # Protected methods
    #------------------------------------------------------------------------

    def _calc_bounding_box(self):
        """
        Returns a 4-tuple (x,y,x2,y2) of the bounding box of all our contained
        components.  Expressed as coordinates in our local coordinate frame.

        Outer (padded) edges are used for every component.  Previously the
        first component contributed its outer edges and the rest their inner
        edges, which gave an inconsistent box when padding was non-zero.
        """
        if not self._components:
            return (0.0, 0.0, 0.0, 0.0)

        first = self._components[0]
        ll_x, ll_y = first.outer_x, first.outer_y
        ur_x, ur_y = first.outer_x2, first.outer_y2

        for component in self._components[1:]:
            ll_x = min(ll_x, component.outer_x)
            ll_y = min(ll_y, component.outer_y)
            ur_x = max(ur_x, component.outer_x2)
            ur_y = max(ur_y, component.outer_y2)
        return (ll_x, ll_y, ur_x, ur_y)

    def _dispatch_draw(self, layer, gc, view_bounds, mode):
        """ Renders the named *layer* of this component.
        """
        new_bounds = self._transform_view_bounds(view_bounds)
        if new_bounds == empty_rectangle:
            return

        if self.layout_needed:
            self.do_layout()

        # Give the container a chance to draw first for the layers that are
        # considered "under" or "at" the main layer level
        if layer in self.container_under_layers:
            my_handler = getattr(self, "_draw_container_" + layer, None)
            if my_handler:
                my_handler(gc, view_bounds, mode)

        # Now transform coordinates and draw the children
        visible_components = self._get_visible_components(new_bounds)
        if visible_components:
            with gc:
                gc.translate_ctm(*self.position)
                for component in visible_components:
                    if component.unified_draw:
                        # Plot containers that want unified_draw only get
                        # called if their draw_layer matches the current layer
                        # we're rendering
                        if component.draw_layer == layer:
                            component._draw(gc, new_bounds, mode)
                    else:
                        component._dispatch_draw(layer, gc, new_bounds, mode)

        # The container's annotation and overlay layers draw over those of
        # its components.
        # FIXME: This needs to be abstracted so that when subclasses override
        # the draw_order list, these are pulled from the subclass list instead
        # of hardcoded here.
        if layer in ("annotation", "overlay", "border"):
            my_handler = getattr(self, "_draw_container_" + layer, None)
            if my_handler:
                my_handler(gc, view_bounds, mode)

        return

    def _draw_container(self, gc, mode="default"):
        "Draw the container background in a specified graphics context"
        pass

    def _draw_container_background(self, gc, view_bounds=None, mode="normal"):
        """ Draw this container's own background layer. """
        self._draw_background(gc, view_bounds, mode)

    def _draw_container_overlay(self, gc, view_bounds=None, mode="normal"):
        """ Draw this container's own overlay layer. """
        self._draw_overlay(gc, view_bounds, mode)

    def _draw_container_underlay(self, gc, view_bounds=None, mode="normal"):
        """ Draw this container's own underlay layer. """
        self._draw_underlay(gc, view_bounds, mode)

    def _draw_container_border(self, gc, view_bounds=None, mode="normal"):
        """ Draw this container's own border layer. """
        self._draw_border(gc, view_bounds, mode)

    def _get_visible_components(self, bounds):
        """ Returns a list of this plot's children that are in the bounds. """
        if bounds is None:
            return [c for c in self.components if c.visible]

        visible_components = []
        for component in self.components:
            if not component.visible:
                continue
            tmp = intersect_bounds(
                component.outer_position + component.outer_bounds, bounds)
            if tmp != empty_rectangle:
                visible_components.append(component)
        return visible_components

    def _should_layout(self, component):
        """ Returns True if it is appropriate for the container to lay out
        the component; False if not.
        """
        if not component:
            return False
        # Invisible components are laid out only when they opt in.
        return bool(component.visible or component.invisible_layout)

    def _should_compact(self):
        """ Returns True if the container needs to call compact().  Subclasses
        can overload this method as needed.

        Always returns a bool; previously an auto-sized container whose
        components all fit returned None implicitly.
        """
        if not self.auto_size:
            return False
        width = self.width
        height = self.height
        for component in self.components:
            x, y = component.outer_position
            if (component.outer_x2 >= width) or \
                    (component.outer_y2 >= height) or (x < 0) or (y < 0):
                return True
        return False

    def _transform_view_bounds(self, view_bounds):
        """
        Transforms the given view bounds into our local space and computes a new
        region that can be handed off to our children.  Returns a 4-tuple of
        the new position+bounds, or None (if None was passed in), or the value
        of empty_rectangle (from enable.base) if the intersection resulted
        in a null region.
        """
        if view_bounds:
            # Check if we are visible
            tmp = intersect_bounds(self.position + self.bounds, view_bounds)
            if tmp == empty_rectangle:
                return empty_rectangle
            # Compute new_bounds, which is the view_bounds transformed into
            # our coordinate space
            v = view_bounds
            new_bounds = (v[0] - self.x, v[1] - self.y, v[2], v[3])
        else:
            new_bounds = None
        return new_bounds

    def _component_bounds_changed(self, component):
        "Called by contained objects when their bounds change"
        # For now, just punt and call compact()
        if self.auto_size:
            self.compact()

    def _component_position_changed(self, component):
        "Called by contained objects when their position changes"
        # For now, just punt and call compact()
        if self.auto_size:
            self.compact()

    #------------------------------------------------------------------------
    # Deprecated interface
    #------------------------------------------------------------------------

    def _draw_overlays(self, gc, view_bounds=None, mode="normal"):
        """ Method for backward compatibility with old drawing scheme.
        """
        warnings.warn("Container._draw_overlays is deprecated.")
        for component in self.overlays:
            component.overlay(component, gc, view_bounds, mode)
        return

    #------------------------------------------------------------------------
    # Property setters & getters
    #------------------------------------------------------------------------

    def _get_components(self):
        """ Getter for the ``components`` property (the shadow list). """
        return self._components

    def _get_layout_needed(self):
        """ True if this container or any contained component needs layout.

        Overrides the parent implementation to also take the contained
        components into account.
        """
        if self._layout_needed:
            return True
        return any(c.layout_needed for c in self.components)

    #------------------------------------------------------------------------
    # Interactor interface
    #------------------------------------------------------------------------

    def normal_mouse_leave(self, event):
        """ Forward a mouse_leave to every component that last saw an event,
        then forget them. """
        event.push_transform(self.get_event_transform(event), caller=self)
        for component in self._prev_event_handlers:
            component.dispatch(event, "mouse_leave")
        self._prev_event_handlers = set()
        event.pop(caller=self)

    def _container_handle_mouse_event(self, event, suffix):
        """
        This method allows the container to handle a mouse event before its
        children get to see it.  Once the event gets handled, its .handled
        should be set to True, and contained components will not be called
        with the event.
        """
        #super(Container, self)._dispatch_stateful_event(event, suffix)
        Component._dispatch_stateful_event(self, event, suffix)

    def get_event_transform(self, event=None, suffix=""):
        """ Affine transform from the parent's frame into our local frame. """
        return affine.affine_from_translation(-self.x, -self.y)

    def _dispatch_stateful_event(self, event, suffix):
        """
        Dispatches a mouse event based on the current event_state.  Overrides
        the default Interactor._dispatch_stateful_event by adding some default
        behavior to send all events to our contained children.

        "suffix" is the name of the mouse event as a suffix to the event state
        name, e.g. "_left_down" or "_window_enter".
        """
        if not event.handled:
            if isinstance(event, BlobFrameEvent):
                # This kind of event does not have a meaningful location. Just
                # let all of the child components see it.
                for component in self._components[::-1]:
                    component.dispatch(event, suffix)
                return

            components = self.components_at(event.x, event.y)

            # Translate the event's location to be relative to this container
            event.push_transform(self.get_event_transform(event, suffix),
                                 caller=self)

            try:
                new_component_set = set(components)

                # For "real" mouse events (i.e., not pre_mouse_* events),
                # notify the previous listening components of a mouse or
                # drag leave
                if not suffix.startswith("pre_"):
                    components_left = self._prev_event_handlers - new_component_set
                    if components_left:
                        leave_event = None
                        if isinstance(event, MouseEvent):
                            leave_event = event
                            leave_suffix = "mouse_leave"
                        elif isinstance(event, DragEvent):
                            leave_event = event
                            leave_suffix = "drag_leave"
                        elif isinstance(event, (BlobEvent, BlobFrameEvent)):
                            # Do not generate a 'leave' event.
                            pass
                        else:
                            # TODO: think of a better way to handle this rare case?
                            leave_event = MouseEvent(x=event.x,
                                                     y=event.y,
                                                     window=event.window)
                            leave_suffix = "mouse_leave"

                        if leave_event is not None:
                            for component in components_left:
                                component.dispatch(leave_event,
                                                   "pre_" + leave_suffix)
                                component.dispatch(leave_event, leave_suffix)
                                # Leave notifications must not consume the
                                # real event.
                                event.handled = False

                    # Notify new components of a mouse enter, if the event is
                    # not a mouse_leave or a drag_leave
                    if suffix not in ("mouse_leave", "drag_leave"):
                        components_entered = \
                            new_component_set - self._prev_event_handlers
                        if components_entered:
                            enter_event = None
                            if isinstance(event, MouseEvent):
                                enter_event = event
                                enter_suffix = "mouse_enter"
                            elif isinstance(event, DragEvent):
                                enter_event = event
                                enter_suffix = "drag_enter"
                            elif isinstance(event,
                                            (BlobEvent, BlobFrameEvent)):
                                # Do not generate an 'enter' event.
                                pass
                            if enter_event:
                                for component in components_entered:
                                    component.dispatch(enter_event,
                                                       "pre_" + enter_suffix)
                                    component.dispatch(enter_event,
                                                       enter_suffix)
                                    # Enter notifications must not consume
                                    # the real event.
                                    event.handled = False

                # Handle the actual event
                # Only add event handlers to the list of previous event handlers
                # if they actually receive the event (and the event is not a
                # pre_* event.
                if not suffix.startswith("pre_"):
                    self._prev_event_handlers = set()
                for component in components:
                    component.dispatch(event, suffix)
                    if not suffix.startswith("pre_"):
                        self._prev_event_handlers.add(component)
                    if event.handled:
                        break
            finally:
                event.pop(caller=self)

            if not event.handled:
                self._container_handle_mouse_event(event, suffix)

        return

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _auto_size_changed(self, old, new):
        """ Re-compute our bounds when auto-sizing is switched on. """
        # For safety, re-compute our bounds
        if new:
            self.compact()
        return

    def _window_resized(self, newsize):
        """ Fit our bounds to a resized window (see fit_window). """
        if newsize is not None:
            self.bounds = [newsize[0] - self.x, newsize[1] - self.y]

    #FIXME: Need a _window_changed to remove this handler if the window changes

    def _fit_window_changed(self, old, new):
        """ Attach or detach the window 'resized' listener. """
        if self._window is not None:
            if not self.fit_window:
                self._window.on_trait_change(self._window_resized,
                                             "resized",
                                             remove=True)
            else:
                self._window.on_trait_change(self._window_resized, "resized")
        return

    def _bounds_changed(self, old, new):
        """ Flag a re-layout and redraw when our bounds are replaced. """
        # crappy... calling our parent's handler seems like a common traits
        # event handling problem
        super(Container, self)._bounds_changed(old, new)
        self._layout_needed = True
        self.invalidate_draw()

    def _bounds_items_changed(self, event):
        """ Flag a re-layout and redraw when a bounds item changes. """
        super(Container, self)._bounds_items_changed(event)
        self._layout_needed = True
        self.invalidate_draw()

    def _bgcolor_changed(self):
        """ Redraw when the background colour changes. """
        self.invalidate_draw()
        self.request_redraw()

    def __components_items_changed(self, event):
        """ Flag a re-layout when the component list is mutated in place. """
        self._layout_needed = True

    def __components_changed(self, event):
        """ Flag a re-layout and redraw when the component list is replaced. """
        self._layout_needed = True
        self.invalidate_draw()

    #-------------------------------------------------------------------------
    # Old / deprecated draw methods; here for backwards compatibility
    #-------------------------------------------------------------------------

    def _draw_component(self, gc, view_bounds=None, mode="normal"):
        """ Draws the component.

        This method is preserved for backwards compatibility. Overrides
        the implementation in Component.
        """
        with gc:
            gc.set_antialias(False)

            self._draw_container(gc, mode)
            self._draw_background(gc, view_bounds, mode)
            self._draw_underlay(gc, view_bounds, mode)
            self._draw_children(gc, view_bounds,
                                mode)  #This was children_draw_mode
            self._draw_overlays(gc, view_bounds, mode)
        return

    def _draw_children(self, gc, view_bounds=None, mode="normal"):
        """ Draw every child whose outer box intersects *view_bounds*
        (old drawing scheme). """
        new_bounds = self._transform_view_bounds(view_bounds)
        if new_bounds == empty_rectangle:
            return

        with gc:
            gc.set_antialias(False)
            gc.translate_ctm(*self.position)
            for component in self.components:
                if new_bounds:
                    tmp = intersect_bounds(
                        component.outer_position + component.outer_bounds,
                        new_bounds)
                    if tmp == empty_rectangle:
                        continue
                with gc:
                    component.draw(gc, new_bounds, mode)
        return
示例#9
0
class Property(HasStrictTraits):
    """ Display attributes: colour, opacity and representation mode. """

    # NOTE(review): this class name shadows ``traits.api.Property`` if both
    # live in the same module namespace; any later ``Property(...)`` trait
    # declaration there would build an instance of this class instead.

    # Three components, each constrained to [0.0, 1.0]; presumably (R, G, B)
    # -- TODO confirm.
    color = Tuple(Range(0.0, 1.0), Range(0.0, 1.0), Range(0.0, 1.0))
    # Constrained to [0.0, 1.0], default 1.0.
    opacity = Range(0.0, 1.0, 1.0)
    # Trait definition comes from ``representation_trait`` (defined elsewhere).
    representation = representation_trait
示例#10
0
class MousePickDispatcher(HasTraits):
    """ An event dispatcher that sends pick events on mouse clicks.

        This object wires VTK observers so that picking callbacks
        can be bound to mouse clicks made without moving the mouse.

        The object deals with adding and removing the VTK-level
        callbacks.
    """

    # The scene events are wired to.
    scene = WeakRef(Scene)

    # The list of callbacks, with the picker type they should be using,
    # and the mouse button that triggers them.
    callbacks = List(Tuple(
                        Callable,
                        Enum('cell', 'point', 'world'),
                        Enum('Left', 'Middle', 'Right'),
                        ),
                    help="The list of callbacks, with the picker type they "
                         "should be using, and the mouse button that "
                         "triggers them. The callback is passed "
                         "as an argument the tvtk picker."
                    )

    #--------------------------------------------------------------------------
    # Private traits
    #--------------------------------------------------------------------------

    # Whether the mouse has moved after the button press
    _mouse_no_mvt = Int

    # The button that has been pressed
    _current_button = Enum('Left', 'Middle', 'Right')

    # The various picker that are used when the mouse is pressed
    _active_pickers = Dict

    # The VTK callback numbers corresponding to our callbacks
    _picker_callback_nbs = Dict(value_trait=Int)

    # The VTK callback numbers corresponding to mouse movement
    _mouse_mvt_callback_nb = Int

    # The VTK callback numbers corresponding to mouse press
    _mouse_press_callback_nbs = Dict

    # The VTK callback numbers corresponding to mouse release
    _mouse_release_callback_nbs = Dict

    #--------------------------------------------------------------------------
    # Callbacks management
    #--------------------------------------------------------------------------

    @on_trait_change('callbacks_items')
    def dispatch_callbacks_change(self, name, trait_list_event):
        """ Wire up added callbacks and tear down removed ones.

        Listens to ``callbacks_items``, i.e. in-place list mutations.
        NOTE(review): wholesale reassignment of ``callbacks`` does not
        appear to be routed here.
        """
        for item in trait_list_event.added:
            self.callback_added(item)
        for item in trait_list_event.removed:
            self.callback_removed(item)


    def callback_added(self, item):
        """ Wire up the different VTK callbacks.

        *item* is a ``(callback, picker_type, button)`` entry of
        ``callbacks``.  Observers are shared: at most one is registered per
        picker type, per button (press and release), and one for mouse
        movement.
        """
        _, picker_type, button = item
        picker = getattr(self.scene.scene.picker, '%spicker' % picker_type)
        self._active_pickers[picker_type] = picker

        # Register the pick callback
        if picker_type not in self._picker_callback_nbs:
            self._picker_callback_nbs[picker_type] = \
                picker.add_observer("EndPickEvent", self.on_pick)

        # For VTK > 5, movement and release are tracked through
        # RenderEvent / EndInteractionEvent instead of the raw mouse events.
        if VTK_VERSION > 5:
            move_event = "RenderEvent"
            release_event = "EndInteractionEvent"
        else:
            move_event = 'MouseMoveEvent'
            release_event = '%sButtonReleaseEvent' % button

        # Register the callbacks on the scene interactor
        interactor = self.scene.scene.interactor
        if not self._mouse_mvt_callback_nb:
            self._mouse_mvt_callback_nb = \
                interactor.add_observer(move_event, self.on_mouse_move)
        if button not in self._mouse_press_callback_nbs:
            self._mouse_press_callback_nbs[button] = \
                interactor.add_observer('%sButtonPressEvent' % button,
                                        self.on_button_press)
        if button not in self._mouse_release_callback_nbs:
            self._mouse_release_callback_nbs[button] = \
                interactor.add_observer(release_event,
                                        self.on_button_release)


    def callback_removed(self, item):
        """ Clean up the unnecessary VTK callbacks.

        The observer ids are dropped from the bookkeeping dicts as well as
        removed from VTK.  Before, stale ids stayed behind, so callback_added()
        skipped re-registering observers for a picker type or button that had
        been removed and was later added again.
        """
        _, picker_type, button = item

        # If the picker is no longer needed, clean up its observers.
        if not any(t == picker_type for _, t, _ in self.callbacks):
            picker = self._active_pickers.pop(picker_type)
            picker.remove_observer(
                self._picker_callback_nbs.pop(picker_type))

        # If there are no longer callbacks on the button, clean up
        # the corresponding observers.
        if not any(b == button for _, _, b in self.callbacks):
            interactor = self.scene.scene.interactor
            interactor.remove_observer(
                    self._mouse_press_callback_nbs.pop(button))
            interactor.remove_observer(
                    self._mouse_release_callback_nbs.pop(button))
        if not self.callbacks and self._mouse_mvt_callback_nb:
            self.scene.scene.interactor.remove_observer(
                            self._mouse_mvt_callback_nb)
            self._mouse_mvt_callback_nb = 0


    def clear_callbacks(self):
        """ Remove every callback one at a time, so each removal is
        dispatched through callback_removed(). """
        while self.callbacks:
            self.callbacks.pop()

    #--------------------------------------------------------------------------
    # Mouse movement dispatch mechanism
    #--------------------------------------------------------------------------

    def on_button_press(self, vtk_picker, event):
        """ Remember which button was pressed and arm the no-movement counter.
        """
        # event is e.g. 'LeftButtonPressEvent' -> 'Left'
        self._current_button = event[:-len('ButtonPressEvent')]
        # NOTE(review): the counter starts at 2, presumably to tolerate one
        # movement/render event right after the press -- confirm.
        self._mouse_no_mvt = 2


    def on_mouse_move(self, vtk_picker, event):
        """ Count down the no-movement counter on every movement event. """
        if self._mouse_no_mvt:
            self._mouse_no_mvt -= 1


    def on_button_release(self, vtk_picker, event):
        """ If the mouse has not moved, pick with our pickers.
        """
        if self._mouse_no_mvt:
            x, y = vtk_picker.GetEventPosition()
            for picker in self._active_pickers.values():
                # Different picker APIs take the position either as a tuple
                # or as separate arguments.
                try:
                    picker.pick((x, y, 0), self.scene.scene.renderer)
                except TypeError:
                    picker.pick(x, y, 0, self.scene.scene.renderer)
        self._mouse_no_mvt = 0


    def on_pick(self, vtk_picker, event):
        """ Dispatch the pick to the callback associated with the
            corresponding mouse button.
        """
        picker = tvtk.to_tvtk(vtk_picker)
        for event_type, event_picker in self._active_pickers.items():
            if picker is event_picker:
                for callback, cb_type, button in self.callbacks:
                    if (cb_type == event_type
                            and button == self._current_button):
                        callback(picker)
                break

    #--------------------------------------------------------------------------
    # Private methods
    #--------------------------------------------------------------------------

    def __del__(self):
        """ Detach all VTK observers when the dispatcher is collected. """
        self.clear_callbacks()
示例#11
0
class Office(HasTraits):
    """ A named location: a city plus a pair of float coordinates. """

    # Pair of float coordinates for the office.
    location = Tuple(Float, Float)

    # Name of the city the office is in.
    city = Str()
示例#12
0
class GeoNDGrid(Source):
    '''
    Specification and representation of an nD-grid.

    GridND

    The grid always lives in 3 coordinate slots (x, y, z); ``active_dims``
    selects which of them are discretized, and ``shape`` gives the number of
    elements along each slot.
    '''
    # The name of our scalar array.
    scalar_name = Str('scalar')

    # map of coordinate labels to the indices
    _dim_map = {'x': 0, 'y': 1, 'z': 2}

    # currently active dimensions
    active_dims = List(Str, ['x', 'y'])

    # Bottom left corner
    x_mins = Instance(GridPoint, label='Corner 1')

    def _x_mins_default(self):
        '''Bottom left corner'''
        return GridPoint()
    # Upper right corner
    x_maxs = Instance(GridPoint, label='Corner 2')

    def _x_maxs_default(self):
        '''Upper right corner'''
        return GridPoint(x=1, y=1, z=1)

    # indices of the currently active dimensions
    dim_indices = Property(Array(int), depends_on='active_dims')

    @cached_property
    def _get_dim_indices(self):
        ''' Get active indices (0 for 'x', 1 for 'y', 2 for 'z'). '''
        return array([self._dim_map[dim_ix]
                      for dim_ix
                      in self.active_dims], dtype='int_')

    # number of currently active dimensions
    n_dims = Property(Int, depends_on='active_dims')

    @cached_property
    def _get_n_dims(self):
        '''Number of currently active dimensions'''
        return len(self.active_dims)

    # number of elements in each direction
    # @todo: rename to n_faces
    shape = Tuple(int, int, int, label='Elements')

    def _shape_default(self):
        '''Number of elements in each direction'''
        return (1, 0, 0)

    n_act_nodes = Property(Array, depends_on='shape, active_dims')

    @cached_property
    def _get_n_act_nodes(self):
        '''Number of active nodes respecting the active_dim.

        Inactive dimensions get a single node; active ones get
        ``shape[i] + 1`` nodes.
        '''
        act_idx = ones((3, ), int)
        shape = array(list(self.shape), dtype=int)
        act_idx[self.dim_indices] += shape[self.dim_indices]
        return act_idx

    # total number of nodes of the system grid
    n_nodes = Property(Int, depends_on='shape, active_dims')

    @cached_property
    def _get_n_nodes(self):
        '''Number of nodes used for the geometry approximation'''
        return reduce(lambda x, y: x * y, self.n_act_nodes)

    enum_nodes = Property(Array, depends_on='shape,active_dims')

    @cached_property
    def _get_enum_nodes(self):
        '''
        Returns an array of node numbers arranged like the grid, i.e. with
        shape ``n_act_nodes``.

        The numbering comes from ``arange(...).reshape(...)`` in the default
        C order, so the z index varies fastest and the x index slowest.
        '''
        return arange(self.n_nodes).reshape(tuple(self.n_act_nodes))

    grid = Property(Array,
                    depends_on='shape,active_dims,x_mins.+,x_maxs.+')

    @cached_property
    def _get_grid(self):
        '''
        slice(start,stop,step) with step of type 'complex' leads to that number of divisions 
        in that direction including 'stop' (see numpy: 'mgrid')
        '''
        slices = [slice(x_min, x_max, complex(0, n_n))
                  for x_min, x_max, n_n
                  in zip(self.x_mins, self.x_maxs, self.n_act_nodes)]
        return mgrid[tuple(slices)]

    #-------------------------------------------------------------------------
    # Visualization pipelines
    #-------------------------------------------------------------------------
#     mvp_mgrid_geo = Trait(MVPolyData)
#
#     def _mvp_mgrid_geo_default(self):
#         return MVPolyData(name='Mesh geomeetry',
#                           points=self._get_points,
#                           lines=self._get_lines,
#                           polys=self._get_faces,
#                           scalars=self._get_random_scalars
#                           )
#
#     mvp_mgrid_labels = Trait(MVPointLabels)
#
#     def _mvp_mgrid_labels_default(self):
#         return MVPointLabels(name='Mesh numbers',
#                              points=self._get_points,
#                              scalars=self._get_random_scalars,
#                              vectors=self._get_points)

    changed = Button('Draw')

    @on_trait_change('changed')
    def redraw(self):
        '''
        Redraw the geometry and label pipelines when 'Draw' is pressed.
        '''
        # NOTE(review): mvp_mgrid_geo / mvp_mgrid_labels are only defined in
        # the commented-out block in this class; unless a subclass or mixin
        # provides them, pressing 'Draw' raises AttributeError - confirm.
        self.mvp_mgrid_geo.redraw()
        self.mvp_mgrid_labels.redraw('label_scalars')

    def _get_points(self):
        '''
        Reshape the grid into a column (one row of x, y, z per node).
        '''
        return c_[tuple([self.grid[i].flatten() for i in range(3)])]

    def _get_n_lines(self):
        '''
        Get the number of lines.

        This is the number of elements, i.e. the product of ``shape`` over
        the active dimensions (inactive dimensions contribute a factor 1).
        For a 1-D grid this equals the number of base nodes used in
        ``_get_lines``.
        '''
        act_idx = ones((3, ), int)
        # ``self.shape`` is a plain tuple, which cannot be indexed by the
        # integer index array ``dim_indices``; convert it first.
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_lines(self):
        '''
        Only return data if n_dims = 1

        Returns an (n_lines, 2) array of node numbers, one row per line.
        '''
        if self.n_dims != 1:
            return array([], int)
        #
        # Get the list of all base nodes (all nodes except the last one
        # along the active dimension)
        #
        tidx = ones((3,), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # Get the node map within the line
        #
        ijk_arr = zeros((3, 2), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 1]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # Setup and fill the array with line connectivities
        #
        n_lines = self._get_n_lines()
        lines = zeros((n_lines, 2), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            lines[n_idx, :] = offsets + base_node
        return lines

    def _get_n_faces(self):
        '''Return the number of faces.

        The number is determined by putting 1 into inactive dimensions and 
        shape into the active dimensions. 
        '''
        act_idx = ones((3, ), int)
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_faces(self):
        '''
        Only return data of n_dims = 2.

        Returns an (n_faces, 4) array of node numbers, one row per face.
        '''
        if self.n_dims != 2:
            return array([], int)
        #
        # get the slices extracting all corner nodes with
        # the smallest node number within the element
        #
        tidx = ones((3,), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # get the node map within the face
        #
        ijk_arr = zeros((3, 4), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 0, 1, 1]
        ijk_arr[self.dim_indices[1]] = [0, 1, 1, 0]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # setup and fill the array with face connectivities
        #
        n_faces = self._get_n_faces()
        faces = zeros((n_faces, 4), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            faces[n_idx, :] = offsets + base_node
        return faces

    def _get_volumes(self):
        '''
        Only return data if ndims = 3

        Returns an (n_volumes, 8) array of node numbers, one row per
        hexahedral element.
        '''
        if self.n_dims != 3:
            return array([], int)

        tidx = ones((3,), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])

        en = self.enum_nodes
        offsets = array([en[0, 0, 0], en[0, 1, 0], en[1, 1, 0], en[1, 0, 0],
                         en[0, 0, 1], en[0, 1, 1], en[1, 1, 1], en[1, 0, 1]], dtype='int_')
        base_node_list = self.enum_nodes[slices].flatten()

        # With all three dimensions active, the element count computed by
        # _get_n_faces is the number of volumes.
        n_volumes = self._get_n_faces()
        volumes = zeros((n_volumes, 8), dtype='int_')
        # Rows are indexed by the element counter, not by the base node
        # number: base node numbers are not contiguous as soon as any
        # dimension has more than one element.
        for n_idx, base_node in enumerate(base_node_list):
            volumes[n_idx, :] = offsets + base_node

        return volumes

    # Identifiers
    var = Str('dummy')
    idx = Int(0)

    def _get_random_scalars(self):
        '''One random (Weibull, a=1) scalar per node.'''
        return random.weibull(1, size=self.n_nodes)

    traits_view = View(HSplit(Group(Item('changed', show_label=False),
                                    Item('active_dims@',
                                         editor=CheckListEditor(values=['x', 'y', 'z'],
                                                                cols=3)),
                                    Item('x_mins@', resizable=False),
                                    Item('x_maxs@'),
                                    Item('shape@'),
                                    ),
                              ),
                       resizable=True)
示例#13
0
class SplashScreen(MSplashScreen, Window):
    """ The toolkit specific implementation of a SplashScreen.  See the
    ISplashScreen interface for the API documentation.
    """

    #### 'ISplashScreen' interface ############################################

    image = Instance(ImageResource, ImageResource('splash'))

    log_level = Int(DEBUG)

    show_log_messages = Bool(True)

    text = Unicode

    text_color = Any

    text_font = Any

    text_location = Tuple(5, 5)

    ###########################################################################
    # Protected 'IWidget' interface.
    ###########################################################################

    def _create_control(self, parent):
        """ Create the wx splash window that shows ``self.image``.

        The window never times out, is centred on screen, has a simple
        border and no taskbar entry.  A paint handler is attached so that
        ``self.text`` can be drawn on top of the image.
        """
        bitmap = self.image.create_image().ConvertToBitmap()
        splash_style = wx.SPLASH_NO_TIMEOUT | wx.SPLASH_CENTRE_ON_SCREEN
        timeout_ms = 0  # Unused: SPLASH_NO_TIMEOUT keeps the window up.
        window_id = -1
        frame_style = wx.SIMPLE_BORDER | wx.FRAME_NO_TASKBAR

        control = wx.SplashScreen(
            bitmap, splash_style, timeout_ms, parent, window_id,
            style=frame_style)

        # Default status-text font: one point larger than, and an italic
        # variant of, the normal system font.  Used by '_on_paint'.
        base_size = wx.NORMAL_FONT.GetPointSize()
        self._wx_default_text_font = new_font_like(
            wx.NORMAL_FONT, point_size=base_size + 1, style=wx.ITALIC)

        # Paint handler that writes the status text onto the splash image.
        wx.EVT_PAINT(control, self._on_paint)

        return control

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _text_changed(self):
        """ Repaint the splash window so that new text becomes visible. """
        control = self.control
        if control is not None:
            # Refresh(False) repaints without erasing the background.
            control.Refresh(False)
            control.Update()
        wx.GetApp().Yield(True)

    def _on_paint(self, event):
        """ Draw ``self.text`` onto the splash image, then let wx carry on. """
        control = self.control
        if control is not None:
            dc = wx.PaintDC(control.GetSplashWindow())

            font = (self._wx_default_text_font if self.text_font is None
                    else self.text_font)
            dc.SetFont(font)

            color = 'black' if self.text_color is None else self.text_color
            dc.SetTextForeground(color)

            left, top = self.text_location
            dc.DrawText(self.text, left, top)

        # Allow the default wx paint handling to run as well.
        event.Skip()
示例#14
0
文件: vixen.py 项目: ajaypraj/vixen
class VixenUI(HasTraits):
    """ UI state object tying together the Vixen project list, the project
    editor and viewer, and the job processor.
    """

    vixen = Instance(Vixen)

    # Which panel is active: the project editor or the project viewer.
    mode = Enum('edit', 'view')

    editor = Instance(ProjectEditor)

    viewer = Instance(ProjectViewer)

    processor = Instance(Processor)

    # True while a block wrapped in ``busy()`` is running.
    is_busy = Bool(False)

    # Location of the HTML documentation (see ``_get_docs``).
    docs = Property(Str)

    log_file = Str

    version = Str

    # Latest user notification as a ``(message, kind, message_id)`` tuple.
    message = Tuple()

    # Private trait to generate message counts.
    _message_count = Int

    def setup_logging_handler(self):
        """ Attach a UIErrorHandler (level ERROR) to the root logger. """
        handler = UIErrorHandler(self)
        handler.setLevel(logging.ERROR)
        root = logging.getLogger()
        root.addHandler(handler)

    def get_context(self):
        """ Return a dict of the main objects (ui, vixen, editor, viewer). """
        return dict(ui=self,
                    vixen=self.vixen,
                    editor=self.editor,
                    viewer=self.viewer)

    def home(self):
        self.mode = 'edit'

    def notify_user(self, message, kind):
        """Meant to just notify the user from the Python side.

        Sets ``self.message`` to ``(message, kind, id)`` with a fresh id so
        that repeated identical messages still produce a change.
        """
        mid = self._get_message_id()
        self.message = message, kind, mid

    def error(self, msg):
        self.notify_user(msg, 'error')
        # NOTE(review): logged at INFO, possibly so it does not re-trigger
        # the ERROR-level handler installed by setup_logging_handler; confirm.
        logger.info("ERROR: %s", msg)

    def info(self, msg):
        self.notify_user(msg, 'info')
        logger.info("INFO: %s", msg)

    def log(self, msg, kind='info'):
        """This method is meant to be called from the Javascript side.

        ``kind`` selects the log level ('info' or 'error'); any other value
        is reported as an error and the message is logged at INFO.
        """
        if kind == 'info':
            logger.info(msg)
        elif kind == 'error':
            logger.error(msg)
        else:
            logger.error('Unknown message kind: %s', kind)
            logger.info(msg)

    def success(self, msg):
        self.notify_user(msg, 'success')
        logger.info("SUCCESS: %s", msg)

    def edit(self, project):
        logger.info('Edit project: %s', project.name)
        self.editor.project = project
        self.mode = 'edit'
        self.info('Remember to "Apply changes" if you change anything.')

    def view(self, project):
        logger.info('View project: %s', project.name)
        self.viewer.project = project
        self.mode = 'view'
        self.editor.project = None
        self.info('Remember to "Save" if you edit any tags.')

    def process(self, project):
        """ Queue jobs from every processor of ``project`` and start them.

        While the viewer is searching, only the search results are
        processed; otherwise every key of the project is.
        """
        jobs = []
        for proc in project.processors:
            if self.viewer.is_searching:
                to_process = [x[1] for x in self.viewer.search_pager.data]
            else:
                to_process = project.keys()
            jobs.extend(proc.make_jobs(to_process, project))
        self.processor.jobs = jobs
        self.processor.process()
        self.info("Remember to save the project once processing completes.")

    def remove(self, project):
        name = project.name
        logger.info('Removing project: %s', name)
        self.vixen.remove(project)
        self.editor.project = None
        self.info('Removed project: %s' % name)

    def add_project(self):
        """ Create a new, empty project and open it in the editor.

        The name is 'Project<N>' where N starts at the current number of
        projects and is increased until the name is not already used (a
        plain count can collide with an existing name after a removal).
        """
        projects = self.vixen.projects
        existing = set(p.name for p in projects)
        index = len(projects)
        name = 'Project%d' % index
        while name in existing:
            index += 1
            name = 'Project%d' % index
        p = Project(name=name)
        self.vixen.add(p)
        self.editor.project = p
        logger.info('Added project %s', name)

    def copy_project(self, project):
        """ Add a copy of ``project`` and open the copy in the editor.

        A project with a save file but no files in memory is loaded first
        so that the copy is not empty.
        """
        name = project.name
        logger.info('Copying project: %s', name)
        if exists(project.save_file) and project.number_of_files == 0:
            project.load()
        p1 = project.copy()
        self.vixen.add(p1)
        self.editor.project = p1

    def save(self):
        """ Save according to the current mode: apply the editor's changes
        in 'edit' mode, save the viewed project in 'view' mode.
        """
        with self.busy():
            if self.mode == 'edit':
                if self.editor is not None and self.editor.project is not None:
                    self.editor.apply()
            elif self.mode == 'view':
                if self.viewer.project is not None:
                    self.viewer.project.save()

    def halt(self):
        """Shut down the webserver.
        """
        logger.info('**** Halting ViXeN ****')
        from tornado.ioloop import IOLoop
        ioloop = IOLoop.instance()
        ioloop.stop()

    @contextmanager
    def busy(self):
        """ Context manager setting ``is_busy`` for the duration of a block,
        resetting it even if the block raises.
        """
        self.is_busy = True
        try:
            yield
        finally:
            self.is_busy = False

    def _get_docs(self):
        """ Path to the bundled or locally built HTML docs, falling back to
        the online documentation URL.
        """
        mydir = dirname(__file__)
        build = join(dirname(mydir), 'docs', 'build', 'html', 'index.html')
        bundled = join(dirname(mydir), 'vixen_data', 'docs', 'html',
                       'index.html')
        # On Windows a leading '/' is prepended to the drive-letter path.
        root = '/' if sys.platform.startswith('win') else ''
        if exists(bundled):
            return root + bundled
        elif exists(build):
            return root + build
        else:
            return 'http://vixen.readthedocs.io'

    def _log_file_default(self):
        return join(get_project_dir(), 'vixen.log')

    def _vixen_default(self):
        v = Vixen()
        v.load()
        return v

    def _editor_default(self):
        return ProjectEditor(ui=self)

    def _viewer_default(self):
        return ProjectViewer(ui=self)

    def _processor_default(self):
        return Processor()

    def _version_default(self):
        import vixen
        return vixen.__version__

    def _get_message_id(self):
        """ Return the current message id and advance the counter, wrapping
        back to 0 after 100.
        """
        mc = self._message_count
        orig = mc
        mc += 1
        if mc > 100:
            mc = 0
        self._message_count = mc
        return orig
示例#15
0
class Perspective(HasTraits):
    """ The default perspective. """

    # The ID of the default perspective.
    DEFAULT_ID = "pyface.workbench.default"

    # The name of the default perspective.
    DEFAULT_NAME = "Default"

    # 'IPerspective' interface ---------------------------------------------

    # The perspective's unique identifier (unique within a workbench window).
    id = Str(DEFAULT_ID)

    # The perspective's name.
    name = Str(DEFAULT_NAME)

    # The contents of the perspective.
    contents = List(PerspectiveItem)

    # The size of the editor area in this perspective. A value of (-1, -1)
    # indicates that the workbench window should choose an appropriate size
    # based on the sizes of the views in the perspective.
    editor_area_size = Tuple((-1, -1))

    # Is the perspective enabled?
    enabled = Bool(True)

    # Should the editor area be shown in this perspective?
    show_editor_area = Bool(True)

    # ------------------------------------------------------------------------
    # 'object' interface.
    # ------------------------------------------------------------------------

    def __str__(self):
        """ Return an informal string representation of the object. """

        return "Perspective(%s)" % self.id

    # ------------------------------------------------------------------------
    # 'Perspective' interface.
    # ------------------------------------------------------------------------

    # Initializers ---------------------------------------------------------

    def _id_default(self):
        """ Trait initializer. """

        # If no Id is specified then use the name.
        return self.name

    # Methods ---------------------------------------------------------------

    def create(self, window):
        """ Create the perspective in a workbench window.

        For most cases you should just be able to set the 'contents' trait to
        lay out views as required. However, you can override this method if
        you want to have complete control over how the perspective is created.

        """

        # Set the size of the editor area (unless the window should choose).
        if self.editor_area_size != (-1, -1):
            window.editor_area_size = self.editor_area_size

        # If the perspective has specific contents then add just those.
        if self.contents:
            self._add_contents(window, self.contents)

        # Otherwise, add all of the views defined in the window at their
        # default positions relative to the editor area.
        else:
            self._add_all(window)

        # Activate the first view in every region.
        window.reset_views()

    def show(self, window):
        """ Called when the perspective is shown in a workbench window.

        The default implementation does nothing, but you can override this
        method if you want to do something whenever the perspective is
        activated.

        """

        return

    # ------------------------------------------------------------------------
    # Private interface.
    # ------------------------------------------------------------------------

    def _add_contents(self, window, contents):
        """ Adds the specified contents. """

        # If we are adding specific contents then we ignore any default view
        # visibility.
        #
        # fixme: This is a bit ugly! Why don't we pass the visibility in to
        # 'window.add_view'?
        for view in window.views:
            view.visible = False

        for item in contents:
            self._add_perspective_item(window, item)

    def _add_perspective_item(self, window, item):
        """ Adds a perspective item to a window. """

        # If no 'relative_to' is specified then the view is positioned
        # relative to the editor area.
        if item.relative_to:
            relative_to = window.get_view_by_id(item.relative_to)

        else:
            relative_to = None

        # fixme: This seems a bit ugly, having to reach back up to the
        # window to get the view. Maybe it's not that bad?
        view = window.get_view_by_id(item.id)
        if view is not None:
            # fixme: This is probably not the ideal way to sync view traits
            # and perspective_item traits.
            view.style_hint = item.style_hint
            # Add the view to the window.
            window.add_view(view, item.position, relative_to,
                            (item.width, item.height))

        else:
            # The reason that we don't just barf here is that a perspective
            # might use views from multiple plugins, and we probably want to
            # continue even if one or two of them aren't present.
            #
            # fixme: This is worth keeping an eye on though. If we end up with
            # a strict mode that throws exceptions early and often for
            # developers, then this might be a good place to throw one ;^)
            logger.error("missing view for perspective item <%s>", item.id)

    def _add_all(self, window):
        """ Adds *all* of the window's views that are currently visible. """

        for view in window.views:
            if view.visible:
                self._add_view(window, view)

    def _add_view(self, window, view):
        """ Adds a view to a window. """

        # If no 'relative_to' is specified then the view is positioned
        # relative to the editor area.
        if view.relative_to:
            relative_to = window.get_view_by_id(view.relative_to)

        else:
            relative_to = None

        # Add the view to the window.
        window.add_view(view, view.position, relative_to,
                        (view.width, view.height))
示例#16
0
class TextPlot(BaseXYPlot):
    """ A plot that positions textual labels in 2D """

    #: text values corresponding to indices
    text = Instance(ArrayDataSource)

    #: The font of the tick labels.
    text_font = KivaFont('modern 10')

    #: The color of the tick labels.
    text_color = black_color_trait

    #: The rotation of the tick labels.
    text_rotate_angle = Float(0)

    #: The margin around the label.
    text_margin = Int(2)

    #: horizontal position of text relative to target point
    h_position = Enum("center", "left", "right")

    #: vertical position of text relative to target point
    v_position = Enum("center", "top", "bottom")

    #: offset of text relative to non-index direction in pixels
    text_offset = Tuple(Float, Float)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    #: flag for whether the cache of Label instances is valid
    _label_cache_valid = Bool(False)

    #: cache of Label instances for faster rendering
    _label_cache = List

    #: cache of bounding boxes of labels
    _label_box_cache = List

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _compute_labels(self, gc):
        """Generate the Label instances for the plot.

        One Label is built per entry of ``self.text`` using the current
        text style traits, and its bounding box (width, height) is cached
        alongside it.
        """
        self._label_cache = [
            Label(
                text=text,
                font=self.text_font,
                color=self.text_color,
                rotate_angle=self.text_rotate_angle,
                margin=self.text_margin
            ) for text in self.text.get_data()
        ]
        self._label_box_cache = [
            array(label.get_bounding_box(gc), float)
            for label in self._label_cache
        ]
        self._label_cache_valid = True

    def _gather_points(self):
        """ Abstract method to collect data points that are within the range of
        the plot, and cache them.

        Points that are non-finite, masked out, or outside either mapper's
        range are dropped; the boolean mask is kept in
        ``_cached_point_mask`` so labels can be filtered the same way.
        """
        if self._cache_valid:
            return

        if not self.index or not self.value:
            return

        index, index_mask = self.index.get_data_mask()
        value, value_mask = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            self._cached_data_pts = []
            self._cached_point_mask = []
            self._cache_valid = True
            return

        index_range_mask = self.index_mapper.range.mask_data(index)
        value_range_mask = self.value_mapper.range.mask_data(value)

        nan_mask = (
            isfinite(index) & index_mask & isfinite(value) & value_mask
        )
        point_mask = nan_mask & index_range_mask & value_range_mask

        if not point_mask.all():
            points = column_stack([index[point_mask], value[point_mask]])
        else:
            points = column_stack([index, value])
        self._cached_data_pts = points
        self._cached_point_mask = point_mask
        self._cache_valid = True

    def _render(self, gc, pts):
        """ Draw each visible label at its screen point.

        Labels and their bounding boxes are filtered with the same point
        mask used by ``_gather_points`` so they line up with ``pts``.  A
        label is drawn from the translated origin, extending box[0] to the
        right and box[1] upwards; the placement offsets below are derived
        from that.  ``pts`` itself is not modified.
        """
        if not self._label_cache_valid:
            self._compute_labels(gc)

        point_mask = self._cached_point_mask
        labels = [
            label
            for label, keep in zip(self._label_cache, point_mask)
            if keep
        ]
        boxes = [
            box
            for box, keep in zip(self._label_box_cache, point_mask)
            if keep
        ]
        offset = empty((2, ), float)

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            for pt, label, box in zip(pts, labels, boxes):
                with gc:
                    if self.h_position == "center":
                        offset[0] = -box[0] / 2 + self.text_offset[0]
                    elif self.h_position == "right":
                        offset[0] = self.text_offset[0]
                    elif self.h_position == "left":
                        # Mirror of "right": the label ends at the point and
                        # the offset pushes it further to the left.
                        offset[0] = -box[0] - self.text_offset[0]
                    if self.v_position == "center":
                        offset[1] = -box[1] / 2 + self.text_offset[1]
                    elif self.v_position == "top":
                        offset[1] = self.text_offset[1]
                    elif self.v_position == "bottom":
                        # Mirror of "top": the label's top edge sits at the
                        # point and the offset pushes it further down.
                        offset[1] = -box[1] - self.text_offset[1]

                    # Compute a new position instead of ``pt += offset``,
                    # which would mutate the caller's ``pts`` array in place.
                    gc.translate_ctm(*(pt + offset))

                    label.draw(gc)

    #------------------------------------------------------------------------
    # Trait events
    #------------------------------------------------------------------------

    @on_trait_change("index.data_changed")
    def _invalidate(self):
        self._cache_valid = False
        self._screen_cache_valid = False
        self._label_cache_valid = False

    @on_trait_change("value.data_changed, text, text.data_changed, "
                     "text_font, text_color, text_rotate_angle, text_margin")
    def _invalidate_labels(self):
        # Labels depend on the text data and on the text style traits, so
        # any change to them requires rebuilding the Label cache.
        self._label_cache_valid = False
示例#17
0
class HeatmapRendererStyle(BaseRendererStyle):
    """ Styling class for heatmap renderers (cmapped image plot).
    """
    #: Transparency of the renderer
    alpha = Range(value=1., low=0., high=1.)

    #: Name of the palette to pick colors from in color direction
    color_palette = Enum(ALL_CHACO_PALETTES)

    #: Chaco color mapper to provide to plot.plot for a cmap_scatter type
    colormap = Property(Any, depends_on="color_palette")

    # Note: this could be encoded in an AxisStyle
    xbounds = Tuple((0, 1))

    auto_xbounds = Tuple((0, 1))

    ybounds = Tuple((0, 1))

    auto_ybounds = Tuple((0, 1))

    reset_xbounds = Button("Reset")

    reset_ybounds = Button("Reset")

    def __init__(self, **traits):
        # When bounds are given but their "auto" counterpart is not, the
        # auto value starts out equal to the given bounds so that the Reset
        # buttons restore them.
        bound_pairs = (("xbounds", "auto_xbounds"),
                       ("ybounds", "auto_ybounds"))
        for bounds_key, auto_key in bound_pairs:
            if bounds_key in traits and auto_key not in traits:
                traits[auto_key] = traits[bounds_key]

        super(HeatmapRendererStyle, self).__init__(**traits)

    def traits_view(self):
        """ Palette/transparency row followed by one row per axis bounds,
        each with its Reset button.
        """
        palette_row = HGroup(
            Item('color_palette'),
            Item('alpha', label="Transparency"),
        )
        xbounds_row = HGroup(
            Item('xbounds'),
            Item('reset_xbounds', show_label=False),
        )
        ybounds_row = HGroup(
            Item('ybounds'),
            Item('reset_ybounds', show_label=False),
        )
        return self.view_klass(VGroup(palette_row, xbounds_row, ybounds_row))

    # Traits listener methods -------------------------------------------------

    def _reset_xbounds_fired(self):
        """ Restore the x bounds to their automatic value. """
        self.xbounds = self.auto_xbounds

    def _reset_ybounds_fired(self):
        """ Restore the y bounds to their automatic value. """
        self.ybounds = self.auto_ybounds

    # Traits initialization methods -------------------------------------------

    def _get_colormap(self):
        """ Colormap registered under the selected palette name, or None if
        that name is not in ``color_map_name_dict``.
        """
        palette = self.color_palette
        if palette not in color_map_name_dict:
            return None
        return color_map_name_dict[palette]

    def _color_palette_default(self):
        return DEFAULT_CONTIN_PALETTE

    def _dict_keys_default(self):
        return ["colormap", "xbounds", "ybounds"]
示例#18
0
class GUIApplication(Application):
    """ A basic Pyface GUI application.

    Owns the Pyface GUI object, the splash screen and about dialog, and the
    list of open windows, tracking which of them is currently active.
    """

    # 'GUIApplication' traits -------------------------------------------------

    # Branding ---------------------------------------------------------------

    #: The splash screen for the application. By default a SplashScreen
    #: (showing ``logo`` if set) is created by ``_splash_screen_default``.
    splash_screen = Instance(ISplashScreen)

    #: The about dialog for the application.
    about_dialog = Instance(IDialog)

    #: Icon for the application (used in window titlebars)
    icon = Image

    #: Logo of the application (used in splash screens and about dialogs)
    logo = Image

    # Window management ------------------------------------------------------

    #: The window factory to use when creating a window for the application.
    # NOTE(review): ``_window_factory_default`` below also supplies a default
    # for this trait; confirm which of the two factories is intended.
    window_factory = Callable(default_window_factory)

    #: Default window size
    window_size = Tuple((800, 600))

    #: Currently active Window if any
    active_window = Instance(IWindow)

    #: List of all open windows in the application
    windows = List(Instance(IWindow))

    #: The Pyface GUI instance for the application
    gui = ReadOnly

    # Protected interface ----------------------------------------------------

    #: Flag if the exiting of the application was explicitly requested by user
    # An 'explicit' exit is when the 'exit' method is called.
    # An 'implicit' exit is when the user closes the last open window.
    _explicit_exit = Bool(False)

    # -------------------------------------------------------------------------
    # 'GUIApplication' interface
    # -------------------------------------------------------------------------

    # Window lifecycle methods -----------------------------------------------

    def create_window(self, **kwargs):
        """ Create a new application window.

        By default uses the :py:attr:`window_factory` to do this.  Subclasses
        can override if they want to do something different or additional.
        Windows without an explicit size get :py:attr:`window_size`, windows
        without a title get the application name, and the application icon
        (if any) is applied.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments to pass to the window factory.

        Returns
        -------
        window : IWindow instance or None
            The new IWindow instance, or None if the window factory did not
            produce a window.
        """
        window = self.window_factory(application=self, **kwargs)
        if window is None:
            # Honour the documented "or None" contract instead of crashing on
            # attribute access below.
            return None

        # (-1, -1) means "no size requested".
        if window.size == (-1, -1):
            window.size = self.window_size
        if not window.title:
            window.title = self.name
        if self.icon:
            window.icon = self.icon

        return window

    def add_window(self, window):
        """ Add a new window to the windows we are tracking, and open it.

        The window is activated only if opening it was not vetoed.
        """

        # Keep a handle on all windows created so that non-active windows don't
        # get garbage collected
        self.windows.append(window)

        # Something might try to veto the opening of the window.
        opened = window.open()
        if opened:
            window.activate()

    # Action handlers --------------------------------------------------------

    def do_about(self):
        """ Display the about dialog, if it exists. """
        if self.about_dialog is not None:
            self.about_dialog.open()

    # -------------------------------------------------------------------------
    # 'Application' interface
    # -------------------------------------------------------------------------

    def start(self):
        """ Start the application, setting up things that are required

        Subclasses should open at least one ApplicationWindow or subclass in
        their start method, and should call the superclass start() method
        before doing any work themselves.

        Returns
        -------
        ok : bool
            The result of the superclass start(); the GUI and the initial
            windows are only created when it is true.
        """
        from pyface.gui import GUI

        ok = super(GUIApplication, self).start()
        if ok:
            # create the GUI so that the splash screen comes up first thing
            if self.gui is Undefined:
                self.gui = GUI(splash_screen=self.splash_screen)

            # create the initial windows to show
            self._create_windows()

        return ok

    # -------------------------------------------------------------------------
    # 'GUIApplication' Private interface
    # -------------------------------------------------------------------------

    def _create_windows(self):
        """ Create the initial windows to display.

        By default calls :py:meth:`create_window` once and adds the result if
        a window was created. Subclasses can override this method.
        """
        window = self.create_window()
        if window is not None:
            self.add_window(window)

    # -------------------------------------------------------------------------
    # 'Application' private interface
    # -------------------------------------------------------------------------

    def _run(self):
        """ Actual implementation of running the application: starting the GUI
        event loop.
        """
        # Fire a notification that the app is running.  This is guaranteed to
        # happen after all initialization has occurred and the event loop has
        # started.  A listener for this event is a good place to do things
        # where you want the event loop running.
        self.gui.invoke_later(self._fire_application_event,
                              'application_initialized')

        # start the GUI - script blocks here
        self.gui.start_event_loop()
        return True

    # Destruction methods -----------------------------------------------------

    def _can_exit(self):
        """ Check with each window to see if it can be closed

        This fires closing events for each window (most recent first), and
        returns False if any listener vetoes.
        """
        if not super(GUIApplication, self)._can_exit():
            return False

        for window in reversed(self.windows):
            window.closing = event = Vetoable()
            if event.veto:
                return False
        return True

    def _prepare_exit(self):
        """ Destroy each window and fire its ``closed`` event. """
        # ensure copy of list, as we modify original list while closing
        for window in list(reversed(self.windows)):
            window.destroy()
            window.closed = window

    def _exit(self):
        """ Shut down the event loop """
        self.gui.stop_event_loop()

    # Trait default handlers ------------------------------------------------

    def _window_factory_default(self):
        """ Default to ApplicationWindow

        This is almost never the right thing, but allows users to get off the
        ground with the base class.
        """
        from pyface.application_window import ApplicationWindow
        return lambda application, **kwargs: ApplicationWindow(**kwargs)

    def _splash_screen_default(self):
        """ Default SplashScreen, showing the application logo if set. """
        from pyface.splash_screen import SplashScreen

        dialog = SplashScreen()
        if self.logo:
            dialog.image = self.logo
        return dialog

    def _about_dialog_default(self):
        """ Default AboutDialog built from the name, company and description.

        Each blank-line-separated paragraph of the description becomes one
        HTML-escaped addition.
        """
        from sys import version_info
        if (version_info.major, version_info.minor) >= (3, 2):
            from html import escape
        else:
            from cgi import escape
        from pyface.about_dialog import AboutDialog

        additions = [
            u"<h1>{}</h1>".format(escape(self.name)),
            u"Copyright &copy; 2018 {}, all rights reserved".format(
                escape(self.company)),
            u"",
        ]
        additions += [escape(line) for line in self.description.split('\n\n')]

        dialog = AboutDialog(
            title=u"About {}".format(self.name),
            additions=additions,
        )
        if self.logo:
            dialog.image = self.logo
        return dialog

    # Trait listeners --------------------------------------------------------

    @on_trait_change('windows:activated')
    def _on_activate_window(self, window, trait, old, new):
        """ Listener that tracks currently active window.
        """
        if window in self.windows:
            self.active_window = window

    @on_trait_change('windows:deactivated')
    def _on_deactivate_window(self, window, trait, old, new):
        """ Listener that tracks currently active window.
        """
        # Only clear the active window if it is the one being deactivated, so
        # an activation of another window that was already recorded is kept.
        if window is self.active_window:
            self.active_window = None

    @on_trait_change('windows:closed')
    def _on_window_closed(self, window, trait, old, new):
        """ Listener that ensures window handles are released when closed.
        """
        if window in self.windows:
            self.windows.remove(window)
        # Do not keep a closed window alive as the "active" one.
        if window is self.active_window:
            self.active_window = None
示例#19
0
class ToolPalette(Widget):
    """ A tool palette laid out as an HTML page of wx bitmap buttons.

    Tools are registered with :meth:`add_tool`.  Once :meth:`realize` has been
    called, adding a tool regenerates the HTML page and recreates every
    button (see :meth:`_reflow`).
    """

    #: Parameter tuples ``(wxid, label, bmp, kind, tooltip, longtip)`` for
    #: every tool, in the order they were added.
    tools = List

    #: Maps a tool id to its parameter tuple.
    id_tool_map = Dict

    #: Maps a tool id to the button currently displaying that tool.
    tool_id_to_button_map = Dict

    #: Size passed to each tool button's constructor.
    button_size = Tuple((25, 25), Int, Int)

    #: True once realize() has been called.
    is_realized = Bool(False)

    #: Maps a tool id to the list of callbacks registered for it.
    tool_listeners = Dict

    #: Maps a button id to its tool id.
    button_tool_map = Dict

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, parent, **traits):
        """ Creates a new tool palette.

        Parameters
        ----------
        parent : wx window
            Parent of the HTML window that hosts the palette.
        **traits
            Trait values forwarded to the base class constructor.
        """

        # Base class constructor.
        super(ToolPalette, self).__init__(**traits)

        # Create the toolkit-specific control that represents the widget.
        self.control = self._create_control(parent)

    ###########################################################################
    # ToolPalette interface.
    ###########################################################################

    def add_tool(self, label, bmp, kind, tooltip, longtip):
        """ Add a tool with the specified properties to the palette.

        Return an id that can be used to reference this tool in the future.
        A ``kind`` of ``'radio'`` creates a toggle button; any other kind
        creates a plain button.  If the palette is already realized, it is
        reflowed immediately.
        """

        wxid = wx.NewId()
        params = (wxid, label, bmp, kind, tooltip, longtip)
        self.tools.append(params)
        self.id_tool_map[wxid] = params

        if self.is_realized:
            self._reflow()

        return wxid

    def toggle_tool(self, id, checked):
        """ Toggle the tool identified by 'id' to the 'checked' state.

        If the button is a toggle or radio button, the button will be checked
        if the 'checked' parameter is True; unchecked otherwise.  If the button
        is a standard button, this method is a NOP.
        """

        button = self.tool_id_to_button_map.get(id, None)
        if button is not None and hasattr(button, 'SetToggle'):
            button.SetToggle(checked)

    def enable_tool(self, id, enabled):
        """ Enable or disable the tool identified by 'id'.

        Unknown ids (or tools whose button has not been created yet because
        the palette is not realized) are ignored.
        """

        button = self.tool_id_to_button_map.get(id, None)
        if button is not None:
            # wx windows are enabled/disabled with Enable(bool); there is no
            # SetEnabled method.
            button.Enable(enabled)

    def on_tool_event(self, id, callback):
        """ Register a callback for events on the tool identified by 'id'.

        The callback is called with the wx button event.
        """

        callbacks = self.tool_listeners.setdefault(id, [])
        callbacks.append(callback)

    def realize(self):
        """ Realize the control so that it can be displayed. """

        self.is_realized = True
        self._reflow()

    def get_tool_state(self, id):
        """ Get the toggle state of the tool identified by 'id'.

        Returns 1 if the tool's button is a toggle button that is currently
        toggled on, and 0 otherwise (including unknown ids).
        """

        button = self.tool_id_to_button_map.get(id, None)
        if hasattr(button, 'GetToggle') and button.GetToggle():
            return 1
        return 0

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_control(self, parent):
        """ Create the HTML window that hosts the tool buttons. """

        return wx.html.HtmlWindow(parent, -1, style=wx.CLIP_CHILDREN)

    def _reflow(self):
        """ Regenerate the HTML page and recreate every tool button. """

        # Create a bit of html for each tool.
        parts = [PART % (str(param[0]), self.button_size)
                 for param in self.tools]

        # Every tool gets a freshly created button below, so forget the ids
        # of the buttons from the previous layout instead of accumulating
        # stale entries.
        self.button_tool_map.clear()

        # Set the HTML on the widget.  This will create all of the buttons.
        self.control.SetPage(HTML % ''.join(parts))

        for param in self.tools:
            self._initialize_tool(param)

    def _initialize_tool(self, param):
        """ Initialize the tool palette button for one tool.

        ``param`` is a parameter tuple as stored in :attr:`tools`.
        """

        wxid, label, bmp, kind, tooltip, longtip = param

        # Placeholder window created from the HTML page for this tool id.
        panel = self.control.FindWindowById(wxid)

        sizer = wx.BoxSizer(wx.VERTICAL)
        panel.SetSizer(sizer)
        panel.SetAutoLayout(True)
        panel.SetWindowStyleFlag(wx.CLIP_CHILDREN)

        from wx.lib.buttons import GenBitmapToggleButton, GenBitmapButton

        if kind == 'radio':
            button = GenBitmapToggleButton(panel,
                                           -1,
                                           None,
                                           size=self.button_size)

        else:
            button = GenBitmapButton(panel, -1, None, size=self.button_size)

        self.button_tool_map[button.GetId()] = wxid
        self.tool_id_to_button_map[wxid] = button
        wx.EVT_BUTTON(panel, button.GetId(), self._on_button)
        button.SetBitmapLabel(bmp)
        # NOTE(review): the label is used as the tooltip; the ``tooltip`` and
        # ``longtip`` parameters are currently ignored — confirm intent.
        button.SetToolTipString(label)
        sizer.Add(button, 0, wx.EXPAND)

    def _on_button(self, event):
        """ Forward a button event to the callbacks registered for its tool. """

        tool_id = self.button_tool_map.get(event.GetId(), None)
        if tool_id is not None:
            for listener in self.tool_listeners.get(tool_id, []):
                listener(event)
示例#20
0
class AttributeDragTool(ValueDragTool):
    """ Tool which modifies a model's attributes as it drags

    This is designed to cover the simplest of drag cases where the drag is
    modifying one or two numerical attributes on an underlying model.  To use,
    simply provide the model object and the attributes that you want to be
    changed by the drag.  If only one attribute is required, the other can be
    left as an empty string.
    """

    #: the model object which has the attributes we are modifying
    model = Any

    #: the name of the attribute that is modified by horizontal motion
    x_attr = Str

    #: the name of the attribute that is modified by vertical motion
    y_attr = Str

    #: (min, max) bounds for the x value.  Each bound is a number, the name
    #: of a model attribute holding the bound, or None for no bound.
    x_bounds = Tuple(Either(Float, Str, None), Either(Float, Str, None))

    #: (min, max) bounds for the y value, with the same conventions as
    #: x_bounds.
    y_bounds = Tuple(Either(Float, Str, None), Either(Float, Str, None))

    #: key for the x value in the ``new_value`` dict; defaults to a
    #: humanized version of x_attr
    x_name = Str

    #: key for the y value in the ``new_value`` dict; defaults to a
    #: humanized version of y_attr
    y_name = Str

    # ValueDragTool API

    def get_value(self):
        """ Get the current value of the attributes

        Returns a 2-tuple of (x, y) values.  If either x_attr or y_attr is
        the empty string, then the corresponding component of the tuple is
        None.
        """
        x_value = None
        y_value = None
        if self.x_attr:
            x_value = getattr(self.model, self.x_attr)
        if self.y_attr:
            y_value = getattr(self.model, self.y_attr)
        return (x_value, y_value)

    def set_delta(self, value, delta_x, delta_y):
        """ Set the current value of the attributes

        Set the underlying attribute values based upon the starting value and
        the provided deltas.  The values are simply set to the sum of the
        appropriate coordinate and the delta, clamped to x_bounds/y_bounds.
        If either x_attr or y_attr is the empty string, then the corresponding
        component is ignored.

        The new values are also published through the ``new_value`` event as
        a dict keyed by x_name/y_name.

        Note that setting x and y are two separate operations, and so will fire
        two trait notification events.
        """
        inspector_value = {}
        if self.x_attr:
            x_value = self._clamp(value[0] + delta_x, self.x_bounds)
            setattr(self.model, self.x_attr, x_value)
            inspector_value[self.x_name] = x_value
        if self.y_attr:
            y_value = self._clamp(value[1] + delta_y, self.y_bounds)
            setattr(self.model, self.y_attr, y_value)
            inspector_value[self.y_name] = y_value
        self.new_value = inspector_value

    # Private helpers

    def _resolve_bound(self, bound):
        """ Return a bound's value, looking string bounds up on the model. """
        if isinstance(bound, str):
            return getattr(self.model, bound)
        return bound

    def _clamp(self, value, bounds):
        """ Clamp value to the (min, max) bounds; None means unbounded. """
        lower, upper = bounds
        if lower is not None:
            value = max(value, self._resolve_bound(lower))
        if upper is not None:
            value = min(value, self._resolve_bound(upper))
        return value

    def _x_name_default(self):
        return self.x_attr.replace("_", " ").capitalize()

    def _y_name_default(self):
        return self.y_attr.replace("_", " ").capitalize()
class BaseMCOOptionsModelView(ModelView):
    """ Base ModelView wrapping a single MCO KPI or parameter model.

    The model's ``name`` is restricted to the names in
    ``available_variables`` (or the empty string), and
    ``verify_workflow_event`` is fired whenever a model trait carrying
    ``verify`` metadata changes.
    """

    # -------------------
    # Required Attributes
    # -------------------

    #: The wrapped model: a KPI specification or an MCO parameter
    model = Either(Instance(KPISpecification), Instance(BaseMCOParameter))

    #: (name, type) pairs of the variables the model may refer to; only
    #: these names are offered for the model name
    # FIXME: this isn't an ideal method, since it requires further
    # work arounds for the name validation. Putting better error
    # handling into the force_bdss could resolve this.
    available_variables = List(Tuple(Identifier, CUBAType))

    # ------------------
    # Regular Attributes
    # ------------------

    #: Whether the KPI/parameter is valid. Updated by
    #: :func:`verify_tree
    #: <force_wfmanager.ui.setup.workflow_tree.WorkflowTree.verify_tree>`
    valid = Bool(True)

    #: Error message describing issues in this modelview. Updated by
    #: :func:`verify_tree
    #: <force_wfmanager.ui.setup.workflow_tree.WorkflowTree.verify_tree>`
    error_message = Str()

    #: Event requesting a verification check of the workflow; fired when
    #: a ``verify``-tagged model trait changes
    verify_workflow_event = Event

    # ------------------
    #     Properties
    # ------------------

    #: Names offered by the model.name EnumEditor in traits_view
    _combobox_values = Property(List(Identifier),
                                depends_on='available_variables')

    def __init__(self, model=None, *args, **kwargs):
        super(BaseMCOOptionsModelView, self).__init__(*args, **kwargs)
        # The model is assigned after base initialization, and only when
        # one was actually supplied.
        if model is None:
            return
        self.model = model

    # ------------------
    #  Property getters
    # ------------------

    def _get__combobox_values(self):
        """Return the first element (the name) of each available variable"""
        variables = self.available_variables
        if variables is None:
            return []
        return [variable[0] for variable in variables]

    # ------------------
    #     Listeners
    # ------------------

    @on_trait_change('model.+verify')
    def model_change(self):
        """Fire the verify workflow event when a verify-tagged trait changes"""
        self.verify_workflow_event = True

    @on_trait_change('model.name,_combobox_values')
    def _check_model_name(self):
        """Reset the model name to '' when it is not one of the available
        variable names."""
        if self.model is None or self._combobox_values is None:
            return
        permitted_names = self._combobox_values + ['']
        if self.model.name not in permitted_names:
            self.model.name = ''
示例#22
0
class ValueDragTool(DragTool):
    """ Abstract tool for modifying a value as the mouse is dragged

    The tool allows the use of an x_mapper and y_mapper to map between
    screen coordinates and more abstract data coordinates.  These mappers must
    be objects with a map_data() method that maps a component-space coordinate
    to a data-space coordinate.  Chaco mappers satisfy the required API, and
    the tool will look for 'x_mapper' and 'y_mapper' attributes on the
    component to use as the defaults, facilitating interoperability with Chaco
    plots. Failing that, a simple identity mapper is provided which does
    nothing. Coordinates are given relative to the component.

    Subclasses of this tool need to supply get_value() and set_delta() methods.
    The get_value method returns the current underlying value, while the
    set_delta method takes the current mapped x and y deltas from the original
    position, and sets the underlying value appropriately.  The object stores
    the original value at the start of the operation as the original_value
    attribute.
    """

    #: set of modifier keys that must be down to invoke the tool
    modifier_keys = Set(Enum(*keys))

    #: mapper that maps from horizontal screen coordinate to data coordinate
    x_mapper = Any

    #: mapper that maps from vertical screen coordinate to data coordinate
    y_mapper = Any

    #: start point of the drag in component coordinates
    original_screen_point = Tuple(Float, Float)

    #: start point of the drag in data coordinates
    original_data_point = Tuple(Any, Any)

    #: initial underlying value
    original_value = Any

    #: new_value event for inspector overlay
    new_value = Event(Dict)

    #: visibility for inspector overlay
    visible = Bool(False)

    def get_value(self):
        """ Return the current value that is being modified
        """
        pass

    def set_delta(self, value, delta_x, delta_y):
        """ Set the value that is being modified

        This function should modify the underlying value based on the provided
        delta_x and delta_y in data coordinates.  These deltas are total
        displacement from the original location, not incremental.  The value
        parameter is the original value at the point where the drag started.
        """
        pass

    # Drag tool API

    def drag_start(self, event):
        """ Record the drag origin (screen and data) and the original value. """
        self.original_screen_point = (event.x, event.y)
        data_x = self.x_mapper.map_data(event.x)
        data_y = self.y_mapper.map_data(event.y)
        self.original_data_point = (data_x, data_y)
        self.original_value = self.get_value()
        self.visible = True
        return True

    def dragging(self, event):
        """ Apply the total data-space displacement since drag_start. """
        position = event.current_pointer_position()
        delta_x = (self.x_mapper.map_data(position[0]) -
                   self.original_data_point[0])
        delta_y = (self.y_mapper.map_data(position[1]) -
                   self.original_data_point[1])
        self.set_delta(self.original_value, delta_x, delta_y)
        return True

    def drag_end(self, event):
        """ Restore the arrow pointer and hide the inspector overlay. """
        event.window.set_pointer("arrow")
        self.visible = False
        return True

    def _drag_button_down(self, event):
        # Override button down to handle modifier keys correctly: the tool
        # only starts when every key in ``keys`` is down exactly when it is
        # listed in ``modifier_keys``.
        if event.handled or self._drag_state != "nondrag":
            return False
        modifiers_match = all(
            getattr(event, key + "_down") == (key in self.modifier_keys)
            for key in keys)
        if not modifiers_match:
            return False
        self.mouse_down_position = (event.x, event.y)
        if not self.is_draggable(*self.mouse_down_position):
            # Reset the real flag (previously a misspelled attribute was set,
            # leaving a stale True from an earlier press).
            self._mouse_down_received = False
            return False
        self._mouse_down_received = True
        return True

    # traits default handlers

    def _x_mapper_default(self):
        # if the component has an x_mapper, try to use it by default
        return getattr(self.component, "x_mapper", identity_mapper)

    def _y_mapper_default(self):
        # if the component has a y_mapper, try to use it by default
        return getattr(self.component, "y_mapper", identity_mapper)
示例#23
0
class BaseFuture(HasStrictTraits):
    """
    Convenience base class for the various flavours of Future.
    """

    # IFuture interface #######################################################

    def cancel(self):
        """
        Request cancellation of the background task.

        Only a task in ``WAITING`` or ``EXECUTING`` state can be cancelled;
        such a task moves to ``CANCELLING`` state straight away. For a task
        in any other state this call has no effect.

        .. versionchanged:: 0.3.0

           This method no longer raises for a task that isn't cancellable.
           In previous versions, :exc:`RuntimeError` was raised.

        Returns
        -------
        cancelled : bool
            True if the task was cancelled, False if the task was not
            cancellable.
        """
        if self.state not in {WAITING, EXECUTING}:
            logger.debug(f"{self} not cancellable; state is {self.state}")
            return False

        self._user_cancelled()
        logger.debug(f"{self} cancelled")
        return True

    def receive(self, message):
        """
        Receive and process a message from the task associated to this future.

        The message is routed to the ``_task_<message_type>`` method of this
        future. This method is primarily for use by the executors, but may
        also be of use in testing.

        Parameters
        ----------
        message : object
            The message received from the associated task, as a
            ``(message_type, message_arg)`` pair.

        Returns
        -------
        final : bool
            True if the received message should be the last one ever received
            from the paired task.
        """
        message_type, message_arg = message
        handler = getattr(self, "_task_{}".format(message_type))
        handler(message_arg)
        is_final = message_type in FINAL_MESSAGES
        return is_final

    # BaseFuture interface ####################################################

    #: The state of the background task, as far as this future knows.
    #: One of the six constants ``WAITING``, ``EXECUTING``, ``COMPLETED``,
    #: ``FAILED``, ``CANCELLING`` or ``CANCELLED``.
    state = Property(FutureState)

    #: Whether cancellation of the background task can be requested. This
    #: is only possible while the future's ``state`` is ``WAITING`` or
    #: ``EXECUTING``.
    cancellable = Property(Bool())

    #: True once communications from the background task are complete,
    #: after which this future's state can no longer change. The value is
    #: True exactly when ``state`` is ``COMPLETED``, ``FAILED`` or
    #: ``CANCELLED``. Listening to this trait for changes is safe: it fires
    #: exactly once, and its value then agrees with the ``state`` trait.
    done = Property(Bool())

    @property
    def result(self):
        """
        Result of the background task.

        Accessible only once the future is in the ``COMPLETED`` state; in
        every other state, reading it raises ``AttributeError``.

        Returns
        -------
        result : object
            The value returned by the background task.

        Raises
        ------
        AttributeError
            If the task is still executing, or was cancelled, or raised an
            exception instead of returning a result.
        """
        if self.state == COMPLETED:
            return self._result
        raise AttributeError(
            "No result available. Task has not yet completed, "
            "or was cancelled, or failed with an exception. "
            "Task state is {}".format(self.state))

    @property
    def exception(self):
        """
        Information about any exception raised by the background task.

        Accessible only once the future is in the ``FAILED`` state; in every
        other state, reading it raises ``AttributeError``.

        Returns
        -------
        exc_info : tuple
            Tuple containing exception information in string form:
            (exception type, exception value, formatted traceback).

        Raises
        ------
        AttributeError
            If the task is still executing, or was cancelled, or completed
            without raising an exception.
        """
        if self.state == FAILED:
            return self._exception
        raise AttributeError(
            "No exception information available. Task has "
            "not yet completed, or was cancelled, or completed "
            "without an exception. "
            "Task state is {}".format(self.state))

    def dispatch(self, message):
        """
        Dispatch a message arriving from the associated BaseTask.

        Convenience routing that subclasses may safely override to use a
        different mechanism. A message of type ``msgtype`` is passed (its
        argument only) to the method ``_process_<msgtype>``, so subclasses
        just need to define the matching ``_process_<msgtype>`` methods.

        Parameters
        ----------
        message : object
            Message sent by the background task. The default implementation of
            this method expects the message to be in the form ``(message_type,
            message_args)`` with ``message_type`` a string.
        """
        message_type, message_arg = message
        processor = getattr(self, "_process_{}".format(message_type))
        processor(message_arg)

    # State transitions #######################################################

    # These methods represent state transitions in response to external events.

    def _task_sent(self, message):
        """
        Handle a custom message sent by the background task.

        While executing, the message is handed to :meth:`dispatch` (which
        subclasses may override). Messages arriving after a cancellation
        request are dropped; in any other internal state the message is
        unexpected.

        Parameters
        ----------
        message : object
            Message from the background task.

        Raises
        ------
        _StateTransitionError
            If the message arrives in an unexpected internal state.
        """
        internal_state = self._internal_state
        if internal_state == EXECUTING:
            self.dispatch(message)
        elif internal_state != _CANCELLING_AFTER_STARTED:
            raise _StateTransitionError(
                "Unexpected custom message in internal state {!r}".format(
                    internal_state))

    def _task_abandoned(self, none):
        """
        Update state when the background task is abandoned due to cancellation.

        Internal state:
        * _CANCELLING_BEFORE_STARTED -> _CANCELLED_ABANDONED

        Parameters
        ----------
        none : NoneType
            This parameter is unused.

        Raises
        ------
        _StateTransitionError
            If the future is in any other internal state.
        """
        if self._internal_state == _CANCELLING_BEFORE_STARTED:
            self._cancel = None
            self._internal_state = _CANCELLED_ABANDONED
        else:
            # Report the message type actually received ('abandoned', not
            # 'started') so state-machine errors are diagnosable.
            raise _StateTransitionError(
                "Unexpected 'abandoned' message in internal state {!r}".format(
                    self._internal_state))

    def _task_started(self, none):
        """
        Update state when the background task has started processing.

        Internal state:
        * WAITING -> EXECUTING
        * _CANCELLING_BEFORE_STARTED -> _CANCELLING_AFTER_STARTED

        Parameters
        ----------
        none : NoneType
            This parameter is unused.

        Raises
        ------
        _StateTransitionError
            If the future is in any other internal state.
        """
        if self._internal_state == WAITING:
            self._internal_state = EXECUTING
        elif self._internal_state == _CANCELLING_BEFORE_STARTED:
            self._internal_state = _CANCELLING_AFTER_STARTED
        else:
            raise _StateTransitionError(
                "Unexpected 'started' message in internal state {!r}".format(
                    self._internal_state))

    def _task_returned(self, result):
        """
        Record a successful result reported by the background task.

        Internal state:
        * EXECUTING -> COMPLETED
        * _CANCELLING_AFTER_STARTED -> _CANCELLED_COMPLETED

        Parameters
        ----------
        result : any
            The object returned by the background task.

        Raises
        ------
        _StateTransitionError
            If the future is in any other internal state.
        """
        current = self._internal_state
        if current == EXECUTING:
            final_state = COMPLETED
        elif current == _CANCELLING_AFTER_STARTED:
            final_state = _CANCELLED_COMPLETED
        else:
            raise _StateTransitionError(
                "Unexpected 'returned' message in internal state {!r}".format(
                    current))

        # The task has finished, so there is nothing left to cancel.
        self._cancel = None
        self._result = result
        self._internal_state = final_state

    def _task_raised(self, exception_info):
        """
        Record a failure reported by the background task.

        Internal state:
        * EXECUTING -> FAILED
        * _CANCELLING_AFTER_STARTED -> _CANCELLED_FAILED

        Parameters
        ----------
        exception_info : tuple
            Three strings describing the failure:
            (exception type, exception value, formatted traceback).

        Raises
        ------
        _StateTransitionError
            If the future is in any other internal state.
        """
        current = self._internal_state
        if current == EXECUTING:
            final_state = FAILED
        elif current == _CANCELLING_AFTER_STARTED:
            final_state = _CANCELLED_FAILED
        else:
            raise _StateTransitionError(
                "Unexpected 'raised' message in internal state {!r}".format(
                    current))

        # The task has finished, so the cancel callback is no longer needed.
        self._cancel = None
        self._exception = exception_info
        self._internal_state = final_state

    def _user_cancelled(self):
        """
        React to a cancellation request from the user.

        The cancel callback is invoked first, and only then does the future
        move into the matching ``CANCELLING`` internal state.

        Internal state:
        * WAITING -> _CANCELLING_BEFORE_STARTED
        * EXECUTING -> _CANCELLING_AFTER_STARTED

        Raises
        ------
        _StateTransitionError
            If the future is neither waiting nor executing.
        """
        current = self._internal_state
        if current == WAITING:
            next_state = _CANCELLING_BEFORE_STARTED
        elif current == EXECUTING:
            next_state = _CANCELLING_AFTER_STARTED
        else:
            raise _StateTransitionError(
                "Unexpected 'cancelled' message in internal state {!r}".format(
                    current))

        self._cancel()
        self._internal_state = next_state

    # Private traits ##########################################################

    #: Callback called (with no arguments) when user requests cancellation.
    #: ``_user_cancelled`` invokes it. It is reset to ``None`` once
    #: cancellation is impossible (on the 'abandoned', 'returned' and
    #: 'raised' transitions).
    _cancel = Callable(allow_none=True)

    #: The internal state of the future. Starts as ``WAITING``. The allowed
    #: values are the keys of ``_INTERNAL_STATE_TO_STATE``, which also maps
    #: each one to the public ``state`` value.
    _internal_state = Enum(WAITING, list(_INTERNAL_STATE_TO_STATE))

    #: Result from the background task, stored by ``_task_returned``.
    _result = Any()

    #: Exception information from the background task, stored by
    #: ``_task_raised`` as three strings:
    #: (exception type, exception value, formatted traceback).
    _exception = Tuple(Str(), Str(), Str())

    # Private methods #########################################################

    def _get_state(self):
        """Getter for the ``state`` property.

        Translates the current internal state to its public counterpart.
        """
        internal = self._internal_state
        return _INTERNAL_STATE_TO_STATE[internal]

    def _get_cancellable(self):
        """Getter for the ``cancellable`` property.

        True exactly when the internal state is one of
        ``_CANCELLABLE_INTERNAL_STATES``.
        """
        internal = self._internal_state
        return internal in _CANCELLABLE_INTERNAL_STATES

    def _get_done(self):
        """Getter for the ``done`` property.

        True exactly when the internal state is one of
        ``_DONE_INTERNAL_STATES``.
        """
        internal = self._internal_state
        return internal in _DONE_INTERNAL_STATES

    @observe("_internal_state")
    def _update_property_traits(self, event):
        """
        Fire change notifications for properties derived from _internal_state.

        For each of ``state``, ``cancellable`` and ``done`` (in that order), the
        derived value is computed for the old and new internal state. A
        ``trait_property_changed`` notification is emitted only when the two
        values differ.
        """
        before, after = event.old, event.new

        derived_properties = [
            ("state", lambda s: _INTERNAL_STATE_TO_STATE[s]),
            ("cancellable", lambda s: s in _CANCELLABLE_INTERNAL_STATES),
            ("done", lambda s: s in _DONE_INTERNAL_STATES),
        ]
        for trait_name, compute in derived_properties:
            old_value = compute(before)
            new_value = compute(after)
            if old_value != new_value:
                self.trait_property_changed(trait_name, old_value, new_value)
示例#24
0
class ToolPaletteManager(ActionManager):
    """ A tool palette manager realizes itself in a tool palette control. """

    # 'ToolPaletteManager' interface ---------------------------------------

    # The size of tool images (width, height).
    image_size = Tuple((16, 16))

    # Should we display the name of each tool bar tool under its image?
    show_tool_names = Bool(True)

    # Private interface ----------------------------------------------------

    # Cache of tool images (scaled to the appropriate size).
    _image_cache = Instance(ImageCache)

    # ------------------------------------------------------------------------
    # 'object' interface.
    # ------------------------------------------------------------------------

    def __init__(self, *args, **traits):
        """ Creates a new tool palette manager.

        Positional arguments and trait keyword arguments are forwarded
        unchanged to the ``ActionManager`` constructor.
        """
        super().__init__(*args, **traits)

        # An image cache to make sure that we only load each image used in the
        # tool palette exactly once. It is sized from ``image_size`` after the
        # base constructor has applied any trait overrides.
        self._image_cache = ImageCache(self.image_size[0], self.image_size[1])

    # ------------------------------------------------------------------------
    # 'ToolPaletteManager' interface.
    # ------------------------------------------------------------------------

    def create_tool_palette(self, parent, controller=None):
        """ Creates, populates and returns a tool palette control.

        Parameters
        ----------
        parent
            The parent passed to the ``ToolPalette`` constructor.
        controller
            Currently unused by this method.
        """
        tool_palette = ToolPalette(parent)

        # Add a tool for every item in the manager's groups to the palette.
        self._add_tools(tool_palette, self.groups)

        self._set_initial_tool_state(tool_palette, self.groups)

        return tool_palette

    # ------------------------------------------------------------------------
    # Private interface.
    # ------------------------------------------------------------------------

    def _add_tools(self, tool_palette, groups):
        """ Adds a palette tool for every item in ``groups``, then realizes
        the palette.

        Each item's ``control_id`` is set to the id returned by its
        ``add_to_palette`` call.
        """
        # FIXME: separators between groups are not supported; it is unclear
        # whether the tool palette needs the notion of a separator.
        for group in groups:
            for item in group.items:
                item.control_id = item.add_to_palette(
                    tool_palette, self._image_cache, self.show_tool_names
                )

        tool_palette.realize()

    def _set_initial_tool_state(self, tool_palette, groups):
        """ Workaround for the wxPython tool bar bug.

        Without this, only the first item in a radio group can be selected
        when the tool bar is first realised.

        For every group whose items are all 'radio' style, each tool is
        toggled to match its action's ``checked`` flag. If none of them is
        checked, the first item's action is marked as checked.
        """
        for group in groups:
            checked = False
            for item in group.items:
                # Every item in a radio group MUST be 'radio' style, so the
                # first non-radio item means this is not a radio group. Skip
                # to the next group, bypassing the loop's ``else`` clause.
                if item.action.style != "radio":
                    break

                # Set the initial checked state of every tool in the group.
                tool_palette.toggle_tool(item.control_id, item.action.checked)
                checked = checked or item.action.checked

            # Only reached when the loop did not break, i.e. the group is a
            # radio group (or has no items at all).
            else:
                # If none of the actions in the group is specified as
                # 'checked', check the first one.
                if not checked and len(group.items) > 0:
                    group.items[0].action.checked = True
示例#25
0
class Cascade2D(HasStrictTraits):
    """
    Caution this has a bug.

    The idea here is to implement a cascade of convolution
    and modulus operations.
    Suppose I had a sequence of wavelets, psi1, psi2, ...

    |x * psi1|
    |x * psi2| -> output
        .
        .
        .
     ---> ||x * psi1| * psi2|
          ||x * psi1| * psi3|
                .               -> output
                .
          ||x * psi2| * psi3|
                .
                .
                  |
                  |
                  ---> .. etc ..
    """

    #: Provides methods for decimating at each layer in the transform.
    decimation = Instance(IDecimationMethod, NoDecimation())

    #: Subsequent convolutions can be applied to downsampled images for
    #  efficiency.
    # Provide some options with Keras for this:
    # Max - MaxPooling (take max value in a window)
    # Average - AveragePooling (average values in window)
    pooling_type = Enum(["none", "max", "average"])

    # Stride - Set a stride when applying the convolutions:
    # Interpreted as "factor 2^n", to apply at each convolution.
    # each step, "0" gives a stride of 1 sample (the default).
    # 1 will apply convolutions at every second sample.
    # For now, negative numbers are not valid.
    stride_log2 = Int(0)

    # Size of the pooling to apply at each step, "0" means no pooling,
    # negative numbers will cause the output to be upsampled by that factor
    pooling_size = Int

    # shape of the input tile
    shape = Tuple(Int)

    #: In 2D we will apply the transform over a set of wavelets at different
    # orientations; define those orientations here, in degrees.
    angles = Tuple

    #: Direction to Keras Conv2d on how to do padding at each convolution,
    #  "same" pads with zeros, "valid" doesn't. This doesn't replace the
    #  need to pad tiles during preprocessing, however, "same" will maintain
    #  tile size through each layer.
    #  "valid" is faster.
    _padding = Enum(["same", "valid"])

    #: private, labels endpoints to attach things to
    _endpoint_counter = Int(0)

    #: Current transform order (not referenced by this class's methods).
    _current_order = Int(1)

    def __init__(self, pooling_type, stride_log2, **traits):
        """
        Parameters
        ----------
        pooling_type - Str
            One of "none", "max" or "average"; lower-cased before use.

        stride_log2 - Int
            Log2 of the stride applied at each convolution; must be >= 0.

        traits
            Any other trait values, forwarded to ``HasStrictTraits``.

        Raises
        ------
        RuntimeError
            If ``stride_log2`` is negative.
        """
        self.pooling_type = pooling_type.lower()
        if stride_log2 < 0:
            raise RuntimeError("stride_log2 needs to be >= 0, we don't support "
                               "upsampling right now.")
        # Store the validated value; otherwise the trait keeps its default 0.
        self.stride_log2 = stride_log2

        super().__init__(**traits)

    def _init_weights(self,
                      shape,
                      node=None,
                      dtype=None,
                      wavelet2d=None,
                      real_part=True):
        """
        Create an initializer for DepthwiseConv2D layers. We need these
        layers instead of Conv2D because we don't want it to stack across
        channels.

        Parameters
        ----------
        shape - tuple
            (nx, ny, num_inp, num_outp): kernel size and the number of
            input/output channels. num_outp must equal len(self.angles).

        node - Node
            Node in the cascade tree, used to resolve decimation scales.

        wavelet2d - IWavelet2D
            An object to create a wavelet.

        dtype - Float
            Data type for the wavelet, default is float32

        real_part - Bool
            If true it will initialize the convolutional weights
            with the real-part of the wavelet, if false, the
            imaginary part.

        Returns
        -------
        returns - tensorflow variable
            returns a tensorflow variable containing the weights.

        Raises
        ------
        RuntimeError
            If the number of output channels differs from len(self.angles).
        """
        if dtype is None:
            dtype = np.float32

        # precompute decimation
        wavelet_stride, conv_stride = self.decimation.resolve_scales(node)

        # we need to normalize by the decimation factor to preserve amplitude
        # (squared because the image is decimated along both axes)
        deci_norm = (wavelet_stride * conv_stride)**2

        # nx/ny is the kernel shape, num_inp/outp are the number of
        # channels input/output.
        nx, ny, num_inp, num_outp = shape

        if num_outp != len(self.angles):
            raise RuntimeError("weights: mismatch dimension num angles.")

        weights = np.zeros(shape, dtype=dtype)

        for iang, ang in enumerate(self.angles):
            wav = wavelet2d.kernel(ang) * deci_norm

            # decimate wavelet
            wav = self.decimation.decimate_wavelet(wav, wavelet_stride)

            # keras does 32-bit real number convolutions
            if real_part:
                x = wav.real.astype(np.float32)
            else:
                x = wav.imag.astype(np.float32)

            # apply to each input channel
            for ichan in range(num_inp):
                weights[:, :, ichan, iang] = x[:nx, :ny]

        return keras_backend.variable(value=weights, dtype=dtype)

    def _convolve_and_abs(self, wavelet, inp, node, trainable=False):
        """
        Implement the operations for |inp*psi|. Initially, there
        will be a channel for each angle defined in the cascade. For
        subsequent convolutions, a abs/conv operation is applied to
        each channel in the input, for each angle defined in the
        cascade.

        For example, if we have 3-angles defined in the cascade (angles)

        transform order | number of channels output
        ----------------------
        order 1         | 3-channels
        order 2         | 9-channels
        order 3         | 27-channels


        Parameters
        ----------
        wavelet - IWavelet2D
            A wavelet object used to generate weights for each angle
            defined in self.angles.
        inp - Keras Layer
            A keras layer to apply the convolution to. For example,
            an input layer. Or subsequently the output of the previous
            convolutions.
        node - Node
            Node in the tree; its name labels the output layer and it
            determines the decimation (and hence the stride) applied.
        trainable - Bool
            Whether the convolution weights are trainable (default False).

        Returns
        -------
        returns - Keras Layer
            The result of the convolution and abs function.
        """
        # create a valid layer name
        name = re.sub("[*,.|_]", "", node.name)

        # strides for decimating the wavelet and for the convolution itself
        wavelet_stride, conv_stride = self.decimation.resolve_scales(node)

        # after decimation
        wavelet_shape = (
            wavelet.shape[0] // wavelet_stride,
            wavelet.shape[1] // wavelet_stride,
        )

        square = Lambda(lambda x: keras_backend.square(x), trainable=False)
        add = Add(trainable=False)

        # The output gets a special name, because it's here we attach
        # things to. We name to the (endpoint)
        sqrt = Lambda(lambda x: keras_backend.sqrt(x),
                      trainable=False,
                      name=name)
        self._endpoint_counter += 1

        # Keyword arguments supplied by Keras (e.g. dtype) are dropped, so
        # the weights are always built with _init_weights' float32 default.
        def real_init(*args, **kwargs):
            return self._init_weights(*args,
                                      node=node,
                                      real_part=True,
                                      wavelet2d=wavelet)

        def imag_init(*args, **kwargs):
            return self._init_weights(*args,
                                      node=node,
                                      real_part=False,
                                      wavelet2d=wavelet)

        real_part = DepthwiseConv2D(
            kernel_size=wavelet_shape,
            depth_multiplier=len(self.angles),
            data_format="channels_last",
            padding=self._padding,
            strides=conv_stride,
            trainable=trainable,
            depthwise_initializer=real_init,
        )(inp)
        real_part = square(real_part)

        imag_part = DepthwiseConv2D(
            kernel_size=wavelet_shape,
            depth_multiplier=len(self.angles),
            data_format="channels_last",
            padding=self._padding,
            strides=conv_stride,
            trainable=trainable,
            depthwise_initializer=imag_init,
        )(inp)
        imag_part = square(imag_part)

        # |z| = sqrt(re^2 + im^2)
        result = add([real_part, imag_part])
        return sqrt(result)

    def _convolve(self, inp, psi, node):
        """
        This computes |inp*psi| by delegating to ``_convolve_and_abs``.
        Any downsampling comes from the decimation resolved for ``node``.
        """
        return self._convolve_and_abs(psi, inp, node)

    def transform(self, cascade_tree, wavelets):
        """
        Apply abs/conv operations to arbitrary order.
        Doesn't apply the DC term, just the subsequent layers.

        Parameters
        ----------
        cascade_tree - object
            Tree describing the cascade. Its ``generate`` method receives
            ``wavelets`` and this object's ``_convolve`` callback, which
            builds |inp * psi| for each node.
        wavelets - sequence
            Wavelets (presumably IWavelet2D instances) passed to
            ``cascade_tree.generate``.

        Returns
        -------
        returns - object
            Whatever ``cascade_tree.get_convolutions()`` returns
            (presumably the Keras layers built for each node; confirm).
        """
        cascade_tree.generate(wavelets, self._convolve)

        return cascade_tree.get_convolutions()
示例#26
0
class Cascade1D(HasStrictTraits):
    """
    Caution this has a bug.

    The idea here is to implement a cascade of convolution
    and modulus operations.
    Suppose I had a sequence of wavelets, psi1, psi2, ...

    |x * psi1|
    |x * psi2| -> output
        .
        .
        .
     ---> ||x * psi1| * psi2|
          ||x * psi1| * psi3|
                .               -> output
                .
          ||x * psi2| * psi3|
                .
                .
                  |
                  |
                  ---> .. etc ..
    """

    #: Provides methods for decimating at each layer in the transform.
    decimation = Instance(IDecimationMethod, NoDecimation())

    #: Subsequent convolutions can be applied to downsampled images for
    #  efficiency.
    # Provide some options with Keras for this:
    # Max - MaxPooling (take max value in a window)
    # Average - AveragePooling (average values in window)
    pooling_type = Enum(["none", "max", "average"])

    # shape of the input tile
    shape = Tuple(Int)

    #: Orientations in degrees. Only meaningful for the 2-D cascade; it is
    #  not referenced by this class's methods.
    angles = Tuple

    #: Direction to Keras Conv1D on how to do padding at each convolution,
    #  "same" pads with zeros, "valid" doesn't. This doesn't replace the
    #  need to pad tiles during preprocessing, however, "same" will maintain
    #  tile size through each layer.
    #  "valid" is faster.
    _padding = Enum(["same", "valid"])

    #: private, labels endpoints to attach things to
    _endpoint_counter = Int(0)

    #: Current transform order (not referenced by this class's methods).
    _current_order = Int(1)

    def _init_weights(self,
                      shape,
                      node=None,
                      dtype=None,
                      wavelet1d=None,
                      real_part=True):
        """
        Create an initializer for Conv1D layers.

        Parameters
        ----------
        shape - tuple
            Rank-3 kernel shape; shape[0] is the kernel length.

        node - Node
            Node in the cascade tree, used to resolve decimation scales.

        wavelet1d - IWavelet1D
            An object to create a wavelet.

        dtype - Float
            Data type for the wavelet, default is float32

        real_part - Bool
            If true it will initialize the convolutional weights
            with the real-part of the wavelet, if false, the
            imaginary part.

        Returns
        -------
        returns - tensorflow variable
            returns a tensorflow variable containing the weights.
        """
        if dtype is None:
            dtype = np.float32

        # precompute decimation
        wavelet_stride, conv_stride = self.decimation.resolve_scales(node)

        # we need to normalize by the decimation factor to preserve amplitude
        deci_norm = (wavelet_stride * conv_stride)

        weights = np.zeros(shape, dtype=dtype)

        wav = wavelet1d.kernel() * deci_norm

        # decimate wavelet
        wav = self.decimation.decimate_wavelet(wav, wavelet_stride)

        # keras does 32-bit real number convolutions
        if real_part:
            x = wav.real.astype(np.float32)
        else:
            x = wav.imag.astype(np.float32)

        # NOTE(review): for a Keras Conv1D kernel, axis 1 is usually the input
        # channels and axis 2 the filters, yet this loops over shape[2]. With
        # the single filter built in _convolve_and_abs and a single input
        # channel, both are 1. Confirm before supporting multi-channel input.
        for ichan in range(shape[2]):
            weights[:, ichan, 0] = x[:shape[0]]

        return keras_backend.variable(value=weights, dtype=dtype)

    def _convolve_and_abs(self, wavelet, inp, node, trainable=False):
        """
        Implement the operations for |inp*psi| in 1-D. Assumes a single
        input channel. If you have multiple input channels you want
        the convolution to apply to, apply it to each one independently.

        Parameters
        ----------
        wavelet - IWavelet1D
            A wavelet object used to generate the convolution weights.
        inp - Keras Layer
            A keras layer to apply the convolution to. For example,
            an input layer. Or subsequently the output of the previous
            convolutions.
        node - Node
            Node in the tree; its name labels the output layer and it
            determines the decimation (and hence the stride) applied.
        trainable - Bool
            Whether the convolution weights are trainable (default False).

        Returns
        -------
        returns - Keras Layer
            The result of the convolution and abs function.
        """

        # create a valid layer name
        name = re.sub("[*,.|_]", "", node.name)

        # strides for decimating the wavelet and for the convolution itself
        wavelet_stride, conv_stride = self.decimation.resolve_scales(node)

        # after decimation
        wavelet_shape = (wavelet.shape[0] // wavelet_stride, )

        square = Lambda(lambda x: keras_backend.square(x), trainable=False)
        add = Add(trainable=False)

        # The output gets a special name, because it's here we attach
        # things to. We name to the (endpoint)
        sqrt = Lambda(lambda x: keras_backend.sqrt(x),
                      trainable=False,
                      name=name)
        self._endpoint_counter += 1

        # ensures proper alignment of subsequent convolutions
        if self._padding == "valid":
            _valid_align = int(wavelet_shape[0] // 2)
            inp = ReflectionPadding1D((_valid_align, _valid_align - 1))(inp)

        # Keras may call an initializer with extra arguments (e.g. a dtype
        # keyword); a one-argument lambda would then raise TypeError. Only the
        # shape is forwarded, so weights use _init_weights' float32 default.
        def real_init(shape, *args, **kwargs):
            return self._init_weights(shape,
                                      node=node,
                                      real_part=True,
                                      wavelet1d=wavelet)

        def imag_init(shape, *args, **kwargs):
            return self._init_weights(shape,
                                      node=node,
                                      real_part=False,
                                      wavelet1d=wavelet)

        real_part = Conv1D(
            1,
            kernel_size=wavelet_shape,
            data_format="channels_last",
            padding=self._padding,
            strides=conv_stride,
            trainable=trainable,
            use_bias=False,
            kernel_initializer=real_init,
        )(inp)
        real_part = square(real_part)

        imag_part = Conv1D(
            1,
            kernel_size=wavelet_shape,
            data_format="channels_last",
            padding=self._padding,
            strides=conv_stride,
            trainable=trainable,
            use_bias=False,
            kernel_initializer=imag_init,
        )(inp)
        imag_part = square(imag_part)

        # |z| = sqrt(re^2 + im^2)
        result = add([real_part, imag_part])
        return sqrt(result)

    def _convolve(self, inp, psi, node):
        """
        This computes |inp*psi| by delegating to ``_convolve_and_abs``.
        Any downsampling comes from the decimation resolved for ``node``.
        """
        return self._convolve_and_abs(psi, inp, node)

    def transform(self, cascade_tree, wavelets):
        """
        Apply abs/conv operations to arbitrary order.
        Doesn't apply the DC term, just the subsequent layers.

        Parameters
        ----------
        cascade_tree - object
            Tree describing the cascade. Its ``generate`` method receives
            ``wavelets`` and this object's ``_convolve`` callback, which
            builds |inp * psi| for each node.
        wavelets - sequence
            Wavelets (presumably IWavelet1D instances) passed to
            ``cascade_tree.generate``.

        Returns
        -------
        returns - object
            Whatever ``cascade_tree.get_convolutions()`` returns
            (presumably the Keras layers built for each node; confirm).
        """
        cascade_tree.generate(wavelets, self._convolve)

        return cascade_tree.get_convolutions()
class SubdivisionDataMapper(AbstractDataMapper):
    """
    A data mapper that uses a uniform grid of rectangular cells. It doesn't make
    any assumptions about the continuity of the input data set, and explicitly
    stores each point in the data set in its cell.

    If the incoming data is ordered in some fashion such that most cells end
    up with large ranges of data, then it's better to use the
    SubdivisionLineDataMapper subclass.
    """
    celltype = Cell
    _last_region = List(Tuple)  # (x, y, w, h) rects covered by the last query
    _cellgrid = Array  # a 2-D numpy object array of Cell objects
    _points_per_cell = Int(100)  # number of datapoints/cell to shoot for
    _cell_lefts = Array  # locations of left edge for all cells
    _cell_bottoms = Array  # locations of bottom edge for all cells
    _cell_extents = Tuple(Float, Float)  # the width and height of a cell

    #-------------------------------------------------------------------
    # Public AbstractDataMapper methods
    #-------------------------------------------------------------------

    def get_points_near(self, pointlist, radius=0.0):
        """
        Return the stacked points of every cell near the given points.

        ``pointlist`` is an Nx2 array of (x, y) points. With a nonzero
        ``radius`` each point is expanded to a square of side 2*radius and
        every cell overlapping it is used. Otherwise only the cell containing
        each point is used. Whole cells are returned, so the result may
        include points farther away than ``radius``.
        """
        if radius != 0:
            # one array of candidate points per query point
            d = 2 * radius
            cell_points = [
                self.get_points_in_rect((px - radius, py - radius, d, d))
                for (px, py) in pointlist
            ]
        else:
            indices = self._get_indices_for_points(pointlist)
            cells = [self._cellgrid[i, j] for (i, j) in indices]
            self._last_region = self._cells_to_rects(indices)
            # unique-ify the list of cells
            cell_points = [c.get_points() for c in set(cells)]
        return vstack(cell_points)

    def get_points_in_rect(self, rect):
        """
        Return the stacked points of every cell overlapping ``rect``.

        ``rect`` is (x, y, width, height). The covered block of cells is
        recorded in ``_last_region`` as a one-element list of (x, y, w, h).
        """
        x_span = (rect[0], rect[0] + rect[2])
        y_span = (rect[1], rect[1] + rect[3])
        min_i, max_i = searchsorted(self._cell_lefts, x_span) - 1
        min_j, max_j = searchsorted(self._cell_bottoms, y_span) - 1
        # A coordinate at or before the first cell edge gives -1, which would
        # wrap around to the last column/row; clamp it into cell 0 instead.
        min_i, max_i = max(min_i, 0), max(max_i, 0)
        min_j, max_j = max(min_j, 0), max(max_j, 0)
        cellpts = [
            self._cellgrid[i, j].get_points()
            for i in range(min_i, max_i + 1)
            for j in range(min_j, max_j + 1)
        ]
        # _last_region is a List(Tuple) trait (as produced by
        # _cells_to_rects), so store the rect wrapped in a list.
        self._last_region = [(
            self._cell_lefts[min_i],
            self._cell_bottoms[min_j],
            (max_i - min_i + 1) * self._cell_extents[0],
            (max_j - min_j + 1) * self._cell_extents[1],
        )]
        return vstack(cellpts)

    def get_last_region(self):
        """Return the rects recorded by the most recent point query."""
        return self._last_region

    #-------------------------------------------------------------------
    # AbstractDataMapper's abstract private methods
    #-------------------------------------------------------------------

    def _update_datamap(self):
        """
        Rebuild the cell grid from ``self._data`` and bucket every point.

        ``self._extents`` is used as ((min_x, min_y), (width, height)): cells
        span ``ll[0] .. ll[0] + ur[0]`` horizontally, and likewise vertically.
        """
        self._last_region = []
        # Create a new grid of the appropriate size, initialize it with new
        # Cell instance (of type self.celltype), and perform point insertion
        # on the new data.
        if self._data is None:
            self._cellgrid = array([], dtype=object)
            self._cell_lefts = array([])
            self._cell_bottoms = array([])
        else:
            num_x_cells, num_y_cells = self._calc_grid_dimensions()
            self._cellgrid = zeros((num_x_cells, num_y_cells), dtype=object)
            for i in range(num_x_cells):
                for j in range(num_y_cells):
                    self._cellgrid[i, j] = self.celltype(parent=self)
            ll, ur = self._extents
            cell_width = ur[0] / num_x_cells
            cell_height = ur[1] / num_y_cells

            # calculate the left and bottom edges of all the cells and store
            # them in two arrays (the half-cell margin keeps floating-point
            # error from producing an extra edge)
            self._cell_lefts = arange(ll[0],
                                      ll[0] + ur[0] - cell_width / 2,
                                      step=cell_width)
            self._cell_bottoms = arange(ll[1],
                                        ll[1] + ur[1] - cell_height / 2,
                                        step=cell_height)

            self._cell_extents = (cell_width, cell_height)

            # insert the data points
            self._basic_insertion(self.celltype)

    def _clear(self):
        """Reset the grid and the cached region to an empty state."""
        self._last_region = []
        self._cellgrid = []
        self._cell_lefts = []
        self._cell_bottoms = []
        self._cell_extents = (0, 0)

    def _sort_order_changed(self, old, new):
        # Trait change notification only fires when the value actually
        # changes, and there are only two sort orders, so it is safe to just
        # reverse our internal _data object.
        self._data = self._data[::-1]
        # _cellgrid is a 2-D object array: iterating it directly yields whole
        # rows, so use .flat to visit every cell, not just the first column.
        for cell in self._cellgrid.flat:
            cell.reverse_indices()

    #-------------------------------------------------------------------
    # helper private methods
    #-------------------------------------------------------------------

    def _calc_grid_dimensions(self):
        """
        Choose (num_x_cells, num_y_cells) so that each cell holds roughly
        ``_points_per_cell`` points and cells follow the data's aspect ratio.
        """
        numpoints = self._data.shape[0]
        numcells = numpoints / self._points_per_cell
        ll, ur = self._extents
        # _update_datamap uses the second element of _extents as
        # (width, height), so the aspect ratio is simply width / height.
        # Subtracting ll here would be wrong whenever the data does not start
        # at the origin, and a negative ratio would make math.sqrt raise.
        aspect_ratio = ur[0] / ur[1]
        num_y_cells = int(math.sqrt(numcells / aspect_ratio))
        num_x_cells = int(aspect_ratio * num_y_cells)
        if num_y_cells == 0:
            num_y_cells += 1
        if num_y_cells * num_x_cells * self._points_per_cell < numpoints:
            num_x_cells += 1
        return (num_x_cells, num_y_cells)

    def _basic_insertion(self, celltype):
        """
        Assign every point in ``self._data`` to its grid cell.

        Consecutive points falling in the same cell are grouped into a run.
        ``RangedCell`` instances receive (start, end) ranges; other cell types
        receive explicit index lists.
        """
        # which (i, j) cell each point in self._data belongs in
        cell_indices = self._get_indices_for_points(self._data)

        # We now look for ranges of points belonging to the same cell.
        # Shift lengthwise and difference: runs of points in the same (i, j)
        # cell give all-zero rows, and a nonzero value in i OR j marks a
        # transition to a new cell.  (Just like find_runs().)
        differences = cell_indices[1:] - cell_indices[:-1]

        # Test the columns for nonzero rather than summing them: a step of
        # (+1, -1) sums to zero and would otherwise hide a change of cell.
        # We add 1 because we shifted cell_indices before differencing.
        diff_indices = nonzero(differences.any(axis=1))[0] + 1

        start_indices = concatenate([[0], diff_indices])
        end_indices = concatenate([diff_indices, [len(self._data)]])

        for start, end in zip(start_indices, end_indices):
            # every point in [start, end) shares the first point's cell
            gridx, gridy = cell_indices[start]
            cell = self._cellgrid[gridx, gridy]
            if celltype == RangedCell:
                cell.add_ranges([(start, end)])
            else:
                cell.add_indices(list(range(start, end)))

    def _get_indices_for_points(self, pointlist):
        """
        Given an input Nx2 array of points, returns a list Nx2 corresponding
        to the column and row indices into the cell grid.
        """
        x_array = searchsorted(self._cell_lefts, pointlist[:, 0]) - 1
        y_array = searchsorted(self._cell_bottoms, pointlist[:, 1]) - 1
        # A coordinate equal to (or below) the first cell edge yields -1,
        # which would index the *last* column/row; clamp it into cell 0.
        x_array[x_array < 0] = 0
        y_array[y_array < 0] = 0
        return array_zip(x_array, y_array)

    def _cells_to_rects(self, cells):
        """
        Converts the extents of a list of cell grid coordinates (i,j) into
        a list of rect tuples (x,y,w,h).  The set should be disjoint, but may
        or may not be minimal.
        """
        # Since this function is generally used to generate clipping regions
        # or other screen-related graphics, we should try to return large
        # rectangular blocks if possible.
        # For now, we just look for horizontal runs and return those.
        cells = array(cells)
        y_sorted = sort_points(cells, index=1)  # sort according to row
        rownums = sort(array(tuple(set(cells[:, 1]))))

        row_start_indices = searchsorted(y_sorted[:, 1], rownums)
        row_end_indices = left_shift(row_start_indices, len(cells))

        rects = []
        for rownum, start, end in zip(rownums, row_start_indices,
                                      row_end_indices):
            # y_sorted is sorted by the J (row) coordinate, so after we
            # extract the column indices, we need to sort them before
            # passing them to find_runs().
            grid_column_indices = sort(y_sorted[start:end][:, 0])
            for span in find_runs(grid_column_indices):
                x = self._cell_lefts[span[0]]
                y = self._cell_bottoms[rownum]
                w = (span[-1] - span[0] + 1) * self._cell_extents[0]
                h = self._cell_extents[1]
                rects.append((x, y, w, h))
        return rects
示例#28
0
class IScale(Interface):
    """
    An interface for various ways we could rescale flow data.
    
    Attributes
    ----------
    name : Str
        The name of this view (for serialization, UI, etc.)
        
    experiment : Instance(Experiment)
        The experiment this scale is to be applied to.  Needed because some
        scales have parameters estimated from data.
        
    channel : Str
        Which channel to scale.  Needed because some scales have parameters
        estimated from data.
        
    condition : Str
        What condition to scale.  Needed because some scales have parameters
        estimated from the a condition.  Must be a numeric condition; else
        instantiating the scale should fail.
        
    statistic : Tuple(Str, Str)
        What statistic to scale.  Needed because some scales have parameters
        estimated from a statistic.  The statistic must be numeric or an
        iterable of numerics; else instantiating the scale should fail.
        
    data : array_like
        What raw data to scale.
    """

    id = Str
    name = Str

    experiment = Instance("cytoflow.experiment.Experiment")

    # what are we using to parameterize the scale?  set one of these; if
    # multiple are set, the first is used.
    channel = Str
    condition = Str
    statistic = Tuple(Str, Str)
    error_statistic = Tuple(Str, Str)
    data = Array

    def __call__(self, data):
        """
        Transforms `data` using this scale.  Must know how to handle int,
        float, and lists, tuples, numpy.ndarrays and pandas.Series of int or
        float.  Must return the same type passed.

        Careful!  May return `NaN` if the scale domain doesn't match the data
        (e.g., applying a log10 scale to negative numbers).
        """

    def inverse(self, data):
        """
        Transforms `data` using the inverse of this scale, i.e. maps scaled
        values back to the original data space.  Must know how to handle int,
        float, and list, tuple, numpy.ndarray and pandas.Series of int or
        float.  Returns the same type as passed.
        """

    def clip(self, data):
        """
        Clips `data` to the scale's domain, so that the result can be passed
        to the scale without falling outside the range it is defined on.
        """

    def norm(self, vmin=None, vmax=None):
        """
示例#29
0
class MncaModel(HasRequiredTraits):
    """
    Board state and update rules for an MNCA (presumably "multiple
    neighbourhood cellular automaton") simulation.

    The board is a 2-D integer array in which 1 marks a live cell and 0 a
    dead one.  Each call to :meth:`evolve_board` convolves the board with
    the mask named by each rule, in list order, and sets the cells whose
    convolution value lies within the rule's limits to the rule's result.
    """

    #: Board size (X, Y)
    board_size = Tuple(Range(1, 500), Range(1, 500), required=True)

    #: Probability (0-1) that each cell starts alive on reset
    reset_life_pct = Float(0.5)

    #: Board for the MNCA
    board = Array(shape=(None, None))

    #: Directory to parse masks from
    masks_dir = Directory(DEFAULT_MASKS_DIR, exists=True)

    #: Available masks (name to array)
    masks = Dict(Unicode, Array(shape=(None, None)))

    #: Rules
    rules = List(Instance(Rule), required=False)

    #: Pause updating the model
    paused = Bool(False)

    #: Drawing brush
    brush = Array(shape=(None, None), value=DEFAULT_BRUSH)

    live_color = Color("white")
    dead_color = Color("black")

    @on_trait_change("masks_dir")
    def set_masks(self):
        """Reload the masks from ``masks_dir`` and draw a new rule set.

        Rules refer to masks by name, so the rule set is regenerated
        (presumably so no rule points at a mask that no longer exists).
        """
        self.masks = load_masks(self.masks_dir)
        self.randomize_rules()

    def _masks_default(self):
        # TODO: this is a bit hacky to help with traits init of this class
        return load_masks(DEFAULT_MASKS_DIR)

    @on_trait_change("board_size")
    def reset_board(self):
        """Allocate a new board of ``board_size`` and randomly seed it."""
        self.board = np.ones(self.board_size, dtype=int)
        self.board_reset()

    @on_trait_change("rules[]")
    def print_new_rules(self):
        """Print the current rule set whenever the rules list changes."""
        print("----------")
        for rule in self.rules:
            print("mask='{rule.mask}', "
                  "acts_on={rule.acts_on!r}, "
                  "lower_limit={rule.lower_limit}, "
                  "upper_limit={rule.upper_limit}, "
                  "result={rule.result})".format(rule=rule))

    def randomize_rules(self):
        """Replace ``rules`` with between 2 and 10 random rules.

        Each rule uses a random mask, a random inclusive window
        ``[lower, upper]`` within ``[0, sum(mask)]``, a random target
        (0, 1 or BOTH) and a random result (DEATH or LIFE).
        """
        rules = []
        for _ in range(random.randint(2, 10)):
            mask_name = random.choice(list(self.masks.keys()))
            mask_total = np.sum(self.masks[mask_name])

            r_a = random.randint(0, mask_total)
            r_b = random.randint(0, mask_total)
            lower, upper = min(r_a, r_b), max(r_a, r_b)

            acts_on = random.choice([0, 1, BOTH])

            result = random.choice([DEATH, LIFE])

            rules.append(
                Rule(mask=mask_name,
                     lower_limit=lower,
                     upper_limit=upper,
                     acts_on=acts_on,
                     result=result))

        self.rules = rules

    def board_reset(self):
        """Set each cell alive with probability ``reset_life_pct``."""
        rows, cols = self.board_size
        for i in range(rows):
            for j in range(cols):
                self.board[i, j] = (
                    1 if random.random() < self.reset_life_pct else 0)

    def clear_board(self, value=0):
        """Fill the whole board with ``value`` (dead cells by default)."""
        self.board[:, :] = value

    def draw(self, target):
        """
        Draw on the board using the current brush at the target coordinates.

        The brush is centred on ``target`` (an (axis-0, axis-1) board index)
        and every non-zero brush cell sets the matching board cell to 1.
        Brush cells that fall outside the board on any side are skipped.
        Previously, negative coordinates wrapped to the opposite edge through
        negative indexing.
        """
        half_brush = np.array(self.brush.shape) // 2
        n_rows, n_cols = self.board.shape
        for offset in np.argwhere(self.brush):
            row, col = target + (offset - half_brush)
            if 0 <= row < n_rows and 0 <= col < n_cols:
                self.board[row, col] = 1

    def evolve_board(self):
        """
        Evolve the board one step according to the rules.

        Does nothing while ``paused`` is set.  Rules are applied in list
        order.  Once a cell has matched one rule, later rules in the same
        step leave it alone.
        """
        if self.paused:
            return

        # TODO: update this in another thread if possible, to stop the UI from juttering

        # 1 where a cell is still eligible for a rule during this step.
        gridmask = np.ones_like(self.board)
        # Reused output buffer for each rule's convolution.
        convgrid = np.zeros_like(self.board)

        # Where the lower and upper bounds of the rule are satisfied
        rule1 = np.ones_like(self.board)
        rule2 = np.ones_like(self.board)

        for rule in self.rules:
            if not gridmask.any():
                # Every cell has already been claimed by an earlier rule.
                break
            # NOTE(review): self.board is updated inside this loop, so later
            # rules convolve the partially-evolved board -- confirm intended.
            convolve(self.board,
                     self.masks[rule.mask],
                     mode="wrap",
                     output=convgrid)
            if rule.lower_limit is not None:
                rule1 = np.where(convgrid >= rule.lower_limit, 1, 0)
            else:
                rule1[:] = 1

            if rule.upper_limit is not None:
                rule2 = np.where(convgrid <= rule.upper_limit, 1, 0)
            else:
                rule2[:] = 1

            slc = (rule1 & rule2 & gridmask)

            if rule.acts_on != BOTH:
                # Restrict the rule to cells whose current value is acts_on.
                acts_on = np.where(self.board == rule.acts_on, 1, 0)
                slc &= acts_on

            hits = np.where(slc)
            self.board[hits] = 1 if rule.result == LIFE else 0
            gridmask[hits] = 0
示例#30
0
class DataRange2D(BaseDataRange):
    """ A range on (2-D) image data.

    In a mathematically general sense, a 2-D range is an arbitrary region in
    the plane.  Arbitrary regions are difficult to implement well, so this
    class supports only rectangular regions for now.

    The rectangle is stored as two independent DataRange1D objects
    (``x_range`` and ``y_range``).  The ``low``/``high`` properties and their
    ``*_setting`` counterparts read and write (x, y) pairs through them.
    """

    # The actual value of the lower bound of this range. To set it, use
    # **low_setting**.
    low = Property  # (2,) array of lower-left x,y
    # The actual value of the upper bound of this range. To set it, use
    # **high_setting**.
    high = Property  # (2,) array of upper-right x,y

    # Property for the lower bound of this range (overrides AbstractDataRange).
    low_setting = Property
    # Property for the upper bound of this range (overrides AbstractDataRange).
    high_setting = Property

    # The 2-D grid range is actually implemented as two 1-D ranges, which can
    # be accessed individually.  They can also be set to new DataRange1D
    # instances; in that case, the DataRange2D's sources are removed from
    # its old 1-D dataranges and added to the new one.

    # Property for the range in the x-dimension.
    x_range = Property
    # Property for the range in the y-dimension.
    y_range = Property

    # Do "auto" bounds imply an exact fit to the data? (One Boolean per
    # dimension) If False, the bounds pad a little bit of margin on either
    # side.
    tight_bounds = Tuple(Bool(True), Bool(True))
    # The minimum percentage difference between low and high for each
    # dimension. That is, (high-low) >= epsilon * low.
    epsilon = Tuple(CFloat(1.0e-4), CFloat(1.0e-4))

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # DataRange1D for the x-dimension.
    _xrange = Instance(DataRange1D, args=())
    # DataRange1D for the y-dimension.
    _yrange = Instance(DataRange1D, args=())

    #------------------------------------------------------------------------
    # AbstractRange interface
    #------------------------------------------------------------------------

    def clip_data(self, data):
        """ Returns a list of data values that are within the range.

        Implements AbstractDataRange.
        """
        return compress(self.mask_data(data), data, axis=0)

    def mask_data(self, data):
        """ Returns a mask array, indicating whether values in the given array
        are inside the range.

        Implements AbstractDataRange.  ``data`` is an array of (x, y) rows.
        Points on the boundary count as inside.
        """
        x_points, y_points = transpose(data)
        # Read each bound once: every property access rebuilds the (x, y)
        # tuple from the two 1-D sub-ranges.
        (x_low, y_low), (x_high, y_high) = self.low, self.high
        x_mask = (x_points >= x_low) & (x_points <= x_high)
        y_mask = (y_points >= y_low) & (y_points <= y_high)
        return x_mask & y_mask

    def bound_data(self, data):
        """ Not implemented for this class.
        """
        raise NotImplementedError("bound_data() has not been implemented "
                                  "for 2d pointsets.")

    def set_bounds(self, low, high):
        """ Sets all the bounds of the range simultaneously.

        Implements AbstractDataRange.

        Parameters
        ----------
        low : (x,y)
            Lower-left corner of the range.
        high : (x,y)
            Upper right corner of the range.
        """
        # NOTE(review): _do_set_low_setting currently ignores fire_event.
        self._do_set_low_setting(low, fire_event=False)
        self._do_set_high_setting(high)

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, *args, **kwargs):
        super(DataRange2D, self).__init__(*args, **kwargs)

    def reset(self):
        """ Resets the bounds of this range.
        """
        self.high_setting = ('auto', 'auto')
        self.low_setting = ('auto', 'auto')
        self.refresh()

    def refresh(self):
        """ If any of the bounds is 'auto', this method refreshes the actual
        low and high values from the set of the view filters' data sources.
        """
        if 'auto' not in self.low_setting and \
           'auto' not in self.high_setting:
            # If the user has hard-coded bounds, then refresh() doesn't do
            # anything.
            return
        else:
            self._refresh_bounds()

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _refresh_bounds(self):
        self._xrange.refresh()
        self._yrange.refresh()

    def _rewire_sources(self, removed, added):
        """ Move the data_changed listener from the *removed* sources to the
        *added* ones, then rebuild both 1-D sub-range source lists from
        ``self.sources`` and refresh.
        """
        for source in removed:
            source.on_trait_change(self.refresh, "data_changed", remove=True)
        for source in added:
            source.on_trait_change(self.refresh, "data_changed")
        # the _xdata and _ydata of the sources may be created anew on every
        # access, so we can't just add/delete from _xrange and _yrange sources
        # based on object identity. So recreate lists each time:
        self._xrange.sources = [s._xdata for s in self.sources]
        self._yrange.sources = [s._ydata for s in self.sources]
        self.refresh()

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _get_low(self):
        return (self._xrange.low, self._yrange.low)

    def _set_low(self, val):
        return self._set_low_setting(val)

    def _get_low_setting(self):
        return (self._xrange.low_setting, self._yrange.low_setting)

    def _set_low_setting(self, val):
        self._do_set_low_setting(val)

    def _do_set_low_setting(self, val, fire_event=True):
        self._xrange.low_setting = val[0]
        self._yrange.low_setting = val[1]

    def _get_high(self):
        return (self._xrange.high, self._yrange.high)

    def _set_high(self, val):
        return self._set_high_setting(val)

    def _get_high_setting(self):
        return (self._xrange.high_setting, self._yrange.high_setting)

    def _set_high_setting(self, val):
        self._do_set_high_setting(val)

    def _do_set_high_setting(self, val, fire_event=True):
        self._xrange.high_setting = val[0]
        self._yrange.high_setting = val[1]

    def _get_x_range(self):
        return self._xrange

    def _set_x_range(self, newrange):
        self._set_1d_range("_xdata", self._xrange, newrange)
        self._xrange = newrange

    def _get_y_range(self):
        return self._yrange

    def _set_y_range(self, newrange):
        self._set_1d_range("_ydata", self._yrange, newrange)
        self._yrange = newrange

    def _set_1d_range(self, dataname, oldrange, newrange):
        # dataname is the name of the underlying 1d data source of the
        # ImageData instances in self.sources, e.g. "_xdata" or "_ydata"
        # NOTE(review): _rewire_sources warns that _xdata/_ydata may be
        # recreated on every access; if so, oldrange.remove() below cannot
        # find them by identity -- confirm DataRange1D.remove semantics.
        for source in self.sources:
            source1d = getattr(source, dataname, None)
            if source1d:
                if oldrange:
                    oldrange.remove(source1d)
                if newrange:
                    newrange.add(source1d)
        return

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _sources_items_changed(self, event):
        self._rewire_sources(event.removed, event.added)

    def _sources_changed(self, old, new):
        self._rewire_sources(old, new)

    @on_trait_change("_xrange.updated,_yrange.updated")
    def _subranges_updated(self):
        self.updated = True