コード例 #1
0
ファイル: test_delegate.py プロジェクト: skailasa/traits
class BazModify(HasTraits):
    # Test fixture: delegates to ``foo`` declared with modify=True (per the
    # traits Delegate docs, assignments are written through to ``foo``), plus
    # static change handlers that store the receiving instance in
    # module-level globals so tests can inspect which handlers fired.
    foo = Instance(Foo, ())
    # prefix="s": the delegated attribute on ``foo`` is ``s``, not ``sd``.
    sd = Delegate("foo", prefix="s", modify=True)
    t = Delegate("foo", modify=True)
    # listenable=False: see the traits Delegate documentation for how this
    # affects change notification on ``u``.
    u = Delegate("foo", listenable=False, modify=True)

    def _s_changed(self, name, old, new):
        # should never be called
        # (this class declares no ``s`` trait of its own; ``sd`` only maps
        # onto ``foo.s`` through its prefix).
        global baz_s_handler_self
        baz_s_handler_self = self
        return

    def _sd_changed(self, name, old, new):
        # Records that the ``sd`` handler ran on this instance.
        global baz_sd_handler_self
        baz_sd_handler_self = self
        return

    def _t_changed(self, name, old, new):
        # Records that the ``t`` handler ran on this instance.
        global baz_t_handler_self
        baz_t_handler_self = self
        return

    def _u_changed(self, name, old, new):
        # Records that the ``u`` handler ran on this instance.
        global baz_u_handler_self
        baz_u_handler_self = self
        return
コード例 #2
0
class IBVPSolve(HasTraits):
    ''' Wraps a ``TLoop`` and exposes its ``tstepper`` and ``rtrace_mngr``
    (delegated) together with the loop itself in a view titled
    'IBVP-Solver'.

    DEPRECATED '''
    tloop = Instance(TLoop)

    def _tloop_default(self):
        ''' Default constructor'''
        return TLoop()

    # Both delegate to the same-named attributes of ``tloop``.
    tstepper = Delegate('tloop')
    rtrace_mngr = Delegate('tloop')

    view = View(Group(Item(name='tloop', style='custom', show_label=False),
                      label='Sim-Control'),
                Group(Item(name='tstepper', style='custom', show_label=False),
                      label='Sim-Model'),
                Group(Item(name='rtrace_mngr',
                           style='custom',
                           show_label=False),
                      label='Sim-Views'),
                title='IBVP-Solver',
                buttons=['OK'],
                resizable=True,
                scrollable=True,
                style='custom',
                x=0.,
                y=0.,
                width=.7,
                height=0.8)
コード例 #3
0
ファイル: test_delegate.py プロジェクト: skailasa/traits
class BazNoModify(HasTraits):
    # Test fixture: same shape as BazModify but with the default modify
    # setting on every delegate. Static change handlers store the receiving
    # instance in module-level globals so tests can inspect which fired.
    foo = Instance(Foo, ())
    # prefix="s": the delegated attribute on ``foo`` is ``s``, not ``sd``.
    sd = Delegate("foo", prefix="s")
    t = Delegate("foo")
    # listenable=False: see the traits Delegate documentation for how this
    # affects change notification on ``u``.
    u = Delegate("foo", listenable=False)

    def _s_changed(self, name, old, new):
        # Records that an ``s`` change handler ran on this instance (this
        # class declares no ``s`` trait of its own).
        global baz_s_handler_self
        baz_s_handler_self = self
        return

    def _sd_changed(self, name, old, new):
        # Records that the ``sd`` handler ran on this instance.
        global baz_sd_handler_self
        baz_sd_handler_self = self
        return

    def _t_changed(self, name, old, new):
        # Records that the ``t`` handler ran on this instance.
        global baz_t_handler_self
        baz_t_handler_self = self
        return

    def _u_changed(self, name, old, new):
        # Records that the ``u`` handler ran on this instance.
        global baz_u_handler_self
        baz_u_handler_self = self
        return
コード例 #4
0
class AbstractCell(HasStrictTraits):
    """ Abstract class for grid cells in a uniform subdivision.

    Individual subclasses store points in different, possibly optimized
    fashion, and performance may be drastically different between different
    cell subclasses for a given set of data.

    Every method here raises NotImplementedError; concrete subclasses must
    override them.
    """
    # The parent of this cell.
    parent = Instance(AbstractDataMapper)

    # The sort traits characterizes the internal points list.
    # (Delegated to ``parent._sort_order``.)
    _sort_order = Delegate('parent')

    # The point array for this cell. This attribute delegates to parent._data,
    # which references the actual point array. For the sake of simplicity,
    # cells assume that _data is sorted in fashion indicated by **_sort_order**.
    # If this doesn't hold, then each cell needs to have its own duplicate
    # copy of the sorted data.
    data = Delegate('parent', '_data')

    # A list of indices into **data** that reflect the points inside this cell.
    # Implemented by subclasses via _get_indices / _set_indices.
    indices = Property

    # Shadow trait for **indices**.
    _indices = Any

    def add_indices(self, indices):
        """ Adds a list of integer indices to the existing list of indices.

        Abstract: subclasses must implement this.
        """
        raise NotImplementedError

    def get_points(self):
        """ Returns a list of points that was previously set.

        This operation might be large and expensive; in general, use the
        **indices** property (backed by _get_indices()) instead.

        Abstract: subclasses must implement this.
        """
        raise NotImplementedError

    def reverse_indices(self):
        """ Tells the cell to manipulate its indices so that they index to the
        same values in a reversed data array.

        Generally this method handles the situation when the parent's _data
        array has been flipped due to a sort order change.

        The length of _data must not have changed; otherwise there is no way to
        know the proper way to manipulate indices.

        Abstract: subclasses must implement this.
        """
        raise NotImplementedError

    def _set_indices(self, indices):
        """ Setter for the **indices** property (abstract). """
        raise NotImplementedError

    def _get_indices(self):
        """  Returns the list of indices into _data that reflect the points
        inside this cell.

        Getter for the **indices** property (abstract).
        """
        raise NotImplementedError
コード例 #5
0
class CBPackage(CBPackageBase):
    """ A package node whose file-system path, dotted name and source text
        are derived from its ``parent`` and its own ``name``.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Parent of this package:
    parent = Any

    # Name of file system path to this package:
    path = Property

    # Fully qualified name of the package (i.e. package.package...):
    full_name = Property

    # Name of the package:
    name = Str

    # Should methods be displayed (delegated to ``parent``):
    show_methods = Delegate('parent')

    # Can we use cached 'pyclbr' file? (delegated to ``parent``)
    cache = Delegate('parent')

    # The text of the package (i.e. the '__init__.py' file):
    text = Property(Code)

    #---------------------------------------------------------------------------
    #  Implementation of the 'path' property:
    #---------------------------------------------------------------------------

    def _get_path(self):
        """ Returns the parent's path joined with this package's name. """
        return join(self.parent.path, self.name)

    #---------------------------------------------------------------------------
    #  Implementation of the 'full_name' property:
    #---------------------------------------------------------------------------

    def _get_full_name(self):
        """ Returns the parent's full name followed by this package's name
            and a trailing '.'.
        """
        return '%s%s.' % (self.parent.full_name, self.name)

    #---------------------------------------------------------------------------
    #  Implementation of the 'text' property:
    #---------------------------------------------------------------------------

    def _get_text(self):
        """ Returns the raw bytes of this package's '__init__.py', reading the
            file only on first access and caching the result in ``_text``.
        """
        # NOTE(review): ``_text`` is not declared here; presumably the base
        # class provides it with a default of None — confirm.
        if self._text is None:
            # The context manager closes the file even if read() raises.
            with open(join(self.path, '__init__.py'), 'rb') as fh:
                self._text = fh.read()

        return self._text

    def _set_text(self):
        """ Setter for ``text``; intentionally does nothing. """
        pass
コード例 #6
0
class SetActivePerspectiveAction(WorkbenchAction):
    """ An action that sets the active perspective. """

    #### 'Action' interface ###################################################

    # Is the action enabled? (delegated to ``perspective``)
    enabled = Delegate('perspective')

    # The action's unique identifier (may be None).
    id = Delegate('perspective')

    # The action's name (displayed on menus/tool bar tools etc).
    name = Delegate('perspective')

    # The action's style.
    style = 'radio'

    #### 'SetActivePerspectiveAction' interface ###############################

    # The perspective that we set the active perspective to.
    perspective = Instance(IPerspective)

    ###########################################################################
    # 'Action' interface.
    ###########################################################################

    def destroy(self):
        """ Destroy the action. """

        self.window = None

        return

    def perform(self, event):
        """ Perform the action. """

        self.window.active_perspective = self.perspective

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    @on_trait_change('perspective,window.active_perspective')
    def _refresh_checked(self):
        """ Refresh the checked state of the action.

        The action is checked when both this action's perspective and the
        window's active perspective exist and have equal ids.
        """

        # Compare ids by value. ``is`` tests object identity, which for
        # strings only holds when both happen to be the same object.
        self.checked = (
            self.perspective is not None
            and self.window is not None
            and self.window.active_perspective is not None
            and self.perspective.id == self.window.active_perspective.id
        )

        return
コード例 #7
0
class SetActivePerspectiveAction(WorkbenchAction):
    """ An action that sets the active perspective. """

    # 'Action' interface ---------------------------------------------------

    # Is the action enabled? (delegated to ``perspective``)
    enabled = Delegate("perspective")

    # The action's unique identifier (may be None).
    id = Delegate("perspective")

    # The action's name (displayed on menus/tool bar tools etc).
    name = Delegate("perspective")

    # The action's style.
    style = "radio"

    # 'SetActivePerspectiveAction' interface -------------------------------

    # The perspective that we set the active perspective to.
    perspective = Instance(IPerspective)

    # ------------------------------------------------------------------------
    # 'Action' interface.
    # ------------------------------------------------------------------------

    def destroy(self):
        """ Destroy the action. """

        self.window = None

    def perform(self, event):
        """ Perform the action. """

        self.window.active_perspective = self.perspective

        return

    # ------------------------------------------------------------------------
    # Private interface.
    # ------------------------------------------------------------------------

    @observe("perspective,window.active_perspective")
    def _refresh_checked(self, event):
        """ Refresh the checked state of the action.

        The action is checked when both this action's perspective and the
        window's active perspective exist and have equal ids.
        """

        # Compare ids by value. ``is`` tests object identity, which for
        # strings only holds when both happen to be the same object.
        self.checked = (
            self.perspective is not None and self.window is not None
            and self.window.active_perspective is not None
            and self.perspective.id == self.window.active_perspective.id)

        return
コード例 #8
0
class RTraceEvalUDomainUnstructuredFieldVar(RTraceEval):
    # Every trait below except ``fets_eval`` is a Delegate with no explicit
    # target name, so it forwards to the same-named attribute of
    # ``fets_eval``. Some of them (e.g. get_vtk_r_glb_arr,
    # get_state_array_size) look like methods; presumably they resolve to
    # the delegate's bound methods — confirm.
    # Held through a WeakRef; presumably so this evaluator does not keep the
    # element evaluator alive — confirm.
    fets_eval = WeakRef(IFETSEval)

    # @TODO Return the parametric coordinates of the element covering the element domain
    #
    vtk_r_arr = Delegate('fets_eval')
    field_entity_type = Delegate('fets_eval')
    dim_slice = Delegate('fets_eval')
    get_vtk_r_glb_arr = Delegate('fets_eval')
    n_vtk_r = Delegate('fets_eval')
    field_faces = Delegate('fets_eval')
    field_lines = Delegate('fets_eval')
    get_state_array_size = Delegate('fets_eval')
    n_vtk_cells = Delegate('fets_eval')
    vtk_cell_data = Delegate('fets_eval')
コード例 #9
0
class Child(HasTraits):
    # A child referencing both parents. Only ``last_name`` is delegated, and
    # it goes to the father.

    mother = Instance(Parent)
    father = Instance(Parent)

    first_name = Str
    # Delegated to ``father.last_name``.
    last_name = Delegate("father")
コード例 #10
0
class Foo(HasTraits):
    # Fixture exercising several trait kinds: simple typed traits, an event,
    # a delegate, and properties that are read-write, read-only (getter
    # only) and write-only (setter only).

    a = Any
    b = Bool
    s = Str
    i = Instance(HasTraits)
    e = Event
    # Delegated to the same-named attribute ``d`` of ``i``.
    d = Delegate("i")

    # Read-write property stored in the plain attribute ``_p``.
    p = Property

    def _get_p(self):
        # Returns the value last written through ``p``.
        return self._p

    def _set_p(self, p):
        # Stores the new value for ``p`` in ``_p``.
        self._p = p

    # Read Only Property
    # (getter only; returns the builtin id() of this object)
    p_ro = Property

    def _get_p_ro(self):
        return id(self)

    # Write-only property
    # (setter only; the written value is kept in ``_p_wo``)
    p_wo = Property

    def _set_p_wo(self, p_wo):
        self._p_wo = p_wo
コード例 #11
0
class RTraceEvalElemFieldVar(RTraceEval):
    # Every trait below is a Delegate with no explicit target name, so it
    # forwards to the same-named attribute of ``ts``. ``ts`` is not declared
    # here; presumably RTraceEval provides it — confirm.

    # To be specialized for element level
    #
    field_entity_type = Delegate('ts')
    vtk_r_arr = Delegate('ts')
    get_vtk_r_glb_arr = Delegate('ts')
    field_vertexes = Delegate('ts')
    field_lines = Delegate('ts')
    field_faces = Delegate('ts')
    field_volumes = Delegate('ts')
    n_vtk_cells = Delegate('ts')
    vtk_cell_data = Delegate('ts')
コード例 #12
0
class Child(HasTraits):
    """ A child whose last name comes from its father. """

    # Integer trait; every change is reported by _age_changed.
    age = Int

    # Must hold a Parent instance.
    father = Instance(Parent)

    # Delegated to ``father.last_name``.
    last_name = Delegate('father')

    def _age_changed(self, old, new):
        """ Static change handler for ``age``: print the old and new values. """
        message = 'Age changed from %s to %s ' % (old, new)
        print(message)
コード例 #13
0
ファイル: traits_ex.py プロジェクト: beijiaer/Visualization
class Child(HasTraits):
    """ A child whose last name comes from its father. """

    age = Int

    # Validation: the value must be a Parent instance.
    father = Instance(Parent)

    # Delegation: forwarded to ``father.last_name``.
    last_name = Delegate("father")

    # Notification: static handler run whenever ``age`` changes.
    def _age_changed(self, old, new):
        text = "Age changed from %s to %s" % (old, new)
        print(text)
コード例 #14
0
class MPPanZoom(BaseTool):
    """ Combines an MPPanTool and an MPDragZoom tool.

    Every event is forwarded to both sub-tools, and ``event_state`` is then
    chosen from the number of tracked blobs: 0 -> 'normal', 1 -> 'pan',
    2 -> 'zoom'.
    """

    pan = Instance(MPPanTool)

    zoom = Instance(MPDragZoom)

    event_state = Enum("normal", "pan", "zoom")

    # Blob bookkeeping is read from the zoom tool.
    _blobs = Delegate('zoom')
    _moves = Delegate('zoom')

    def _dispatch_stateful_event(self, event, suffix):
        # Reset ``handled`` between the two dispatches so the pan tool still
        # processes the event after the zoom tool has seen it.
        self.zoom.dispatch(event, suffix)
        event.handled = False
        self.pan.dispatch(event, suffix)

        blob_count = len(self._blobs)
        next_state = {0: 'normal', 1: 'pan', 2: 'zoom'}.get(blob_count)
        if next_state is None:
            # Only reached with more than two blobs, so this assertion fails
            # whenever it runs (unless assertions are disabled).
            assert blob_count <= 2
        else:
            self.event_state = next_state

        if suffix == 'blob_down':
            # Re-capture the blob for this tool and consume the event.
            event.window.release_blob(event.bid)
            event.window.capture_blob(self, event.bid, event.net_transform())
            event.handled = True
        elif suffix == 'blob_up':
            event.window.release_blob(event.bid)

    def _component_changed(self, old, new):
        """ Keep both sub-tools pointed at the same component. """
        self.pan.component = new
        self.zoom.component = new

    def _pan_default(self):
        """ Create the pan tool for the current component. """
        return MPPanTool(self.component)

    def _zoom_default(self):
        """ Create the zoom tool for the current component. """
        return MPDragZoom(self.component)
コード例 #15
0
class Child(HasTraits):
    """ A child whose last name comes from its father. """

    age = Int
    # Validation: the value of ``father`` must be a Parent instance.
    father = Instance(Parent)
    # Delegation: this Child's ``last_name`` is forwarded to
    # ``father.last_name``.
    last_name = Delegate('father')

    # Notification: runs whenever the value of ``age`` is modified.
    def _age_changed(self, old, new):
        line = 'age changed from %s to %s' % (old, new)
        print(line)
コード例 #16
0
class Child(HasTraits):
    """ A child whose last name comes from its father. """

    age = Int

    # Validation: ``father`` must hold a Parent instance.
    father = Instance(Parent)

    # Delegation: ``last_name`` is forwarded to ``father.last_name``.
    last_name = Delegate('father')

    # Notification: called every time ``age`` is modified.
    def _age_changed(self, old, new):
        report = "Age changed from %s to %s" % (old, new)
        print(report)
コード例 #17
0
class ToolbarButton(Button):
    """ A Button that, when given a ``toolbar`` keyword, stores it and adds
    itself to that toolbar after construction.
    """

    # The toolbar this button was added to (None if none was given).
    toolbar = Any

    # Delegated to ``toolbar.canvas``.
    canvas = Delegate("toolbar")

    def __init__(self, *args, **kw):
        # Remove ``toolbar`` before calling the base initializer; attach only
        # once the button is fully constructed.
        toolbar = kw.pop("toolbar", None)
        super().__init__(*args, **kw)
        # Explicit None check: a toolbar object that is falsy (e.g. one
        # whose __len__ is 0 while it is still empty) must still be attached.
        if toolbar is not None:
            self.toolbar = toolbar
            toolbar.add(self)
コード例 #18
0
class TimeInOut(SamplesGenerator):
    """
    Base class for time-domain processing blocks: reads samples from
    :attr:`source` and produces output through the generator :meth:`result`.
    This base implementation passes the data through unchanged.
    """

    #: Data source; :class:`~acoular.sources.SamplesGenerator` or derived object.
    source = Trait(SamplesGenerator)

    #: Sampling frequency of output signal, as given by :attr:`source`.
    sample_freq = Delegate('source')

    #: Number of channels in output, as given by :attr:`source`.
    numchannels = Delegate('source')

    #: Number of samples in output, as given by :attr:`source`.
    numsamples = Delegate('source')

    # internal identifier (recomputed when the source's digest changes)
    digest = Property(depends_on=['source.digest'])

    traits_view = View(Item('source', style='custom'))

    @cached_property
    def _get_digest(self):
        # ``digest`` here is the module-level helper function, not the trait.
        return digest(self)

    def result(self, num):
        """
        Python generator that re-yields every block produced by
        ``self.source.result(num)`` unchanged: blocks of shape
        (num, :attr:`numchannels`), the last of which may be shorter than num.
        """
        upstream = self.source.result(num)
        for block in upstream:
            # pass-through: no processing is applied
            yield block
コード例 #19
0
class Child(HasTraits):
    """ A child whose last name comes from its father. """

    age = Int

    # VALIDATION: 'father' must be a Parent instance:
    father = Instance(Parent)

    # DELEGATION: 'last_name' is delegated to father's 'last_name':
    last_name = Delegate('father')

    # NOTIFICATION: This method is called when 'age' changes:
    def _age_changed(self, old, new):
        # print() with a single argument prints the same text under Python 2
        # and is valid syntax under Python 3 (the print statement is not).
        print('Age changed from %s to %s ' % (old, new))
コード例 #20
0
class VirtualDataName(HasPrivateTraits):
    # A stand-in for a TemplateDataName: ``description`` is delegated to the
    # wrapped object, and value0..value5 are VirtualValue traits with indexes
    # 0..5. NOTE(review): VirtualValue is defined elsewhere; presumably each
    # one exposes element ``index`` of a value sequence — confirm.

    # The TemplateDataName this is a virtual copy of:
    data_name = Instance(TemplateDataName)

    # The data name description:
    # (modify=True: per the traits Delegate docs, assignments are written
    # through to ``data_name.description``.)
    description = Delegate('data_name', modify=True)

    # The 'virtual' traits of this object:
    value0 = VirtualValue(index=0)
    value1 = VirtualValue(index=1)
    value2 = VirtualValue(index=2)
    value3 = VirtualValue(index=3)
    value4 = VirtualValue(index=4)
    value5 = VirtualValue(index=5)
コード例 #21
0
class Child(HasTraits):
    """ A child whose age drives a plotted integration, with a button that
    starts or stops a capture thread.
    """

    age = Int

    start_stop_capture = Button()
    view = View(Item('start_stop_capture', show_label=False))

    def _start_stop_capture_fired(self):
        """ Ask a running capture thread to stop; otherwise start a new one. """
        # ``capture_thread`` is not declared as a trait, so it does not exist
        # before the first start; getattr avoids an AttributeError on the
        # first button press.
        capture_thread = getattr(self, 'capture_thread', None)
        # Thread.isAlive() was removed in Python 3.9; is_alive() is the
        # supported spelling.
        if capture_thread and capture_thread.is_alive():
            capture_thread.wants_abort = True
        else:
            self.capture_thread = CaptureThread()
            self.capture_thread.wants_abort = False
            # NOTE(review): ``display`` is not declared on this class —
            # presumably set elsewhere; confirm.
            self.capture_thread.display = self.display
            self.capture_thread.start()

    # VALIDATION: 'father' must be a Parent instance:
    father = Instance(Parent)

    # DELEGATION: 'last_name' is delegated to father's 'last_name':
    last_name = Delegate('father')

    def eval_plot(self):
        """ Step ``euler_step`` once per point of a 101-point grid on [0, 1],
        using ``age`` as ``k``, starting from y=0 and dy=1, and plot y
        against the grid.
        """
        x = numpy.linspace(0, 1, 101)
        y = []
        yi = 0.
        dyi = 1.
        k = self.age
        # NOTE(review): ``m`` and ``a`` are not defined in this class —
        # presumably module-level globals; confirm.
        for _ in range(len(x)):
            y.append(yi)
            dyi_inc, yi_inc = euler_step(k, m, a, yi, dyi)
            dyi += dyi_inc
            yi += yi_inc
        plt.plot(x, numpy.asarray(y))
        plt.show()

    # NOTIFICATION: This method is called when 'age' changes:
    def _age_changed(self, old, new):
        # print() call form is valid in both Python 2 and 3.
        print('Age changed from %s to %s ' % (old, new))
        self.eval_plot()
コード例 #22
0
class Workbench(pyface.Workbench):
    """ The Envisage workbench.

    An application usually has exactly *one* workbench, which may create
    any number of workbench windows.

    """

    #### 'pyface.Workbench' interface #########################################

    # Factory used to build new workbench windows.
    window_factory = WorkbenchWindow

    #### 'Workbench' interface ################################################

    # The application this workbench belongs to.
    application = Instance(IApplication)

    # Whether to ask for confirmation before exiting (read from the
    # workbench preferences).
    prompt_on_exit = Delegate("_preferences")

    #### Private interface ####################################################

    # The workbench preferences.
    _preferences = Instance(WorkbenchPreferences, ())

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _exiting_changed(self, event):
        """ Handle the workbench starting to exit.

        When ``prompt_on_exit`` is set, the active window asks the user to
        confirm; any answer other than YES vetoes the exit.
        """
        if not self.prompt_on_exit:
            return

        window = self.active_window
        question = "Exit %s?" % window.title
        if window.confirm(question, "Confirm Exit") != YES:
            event.veto = True
コード例 #23
0
from uuid import uuid4

# Enthought library imports
from traits.api \
    import Any, Bool, Delegate, Enum, Float, Instance, Int, List, \
           Property, Str, Trait
from kiva.constants import FILL, STROKE

# Local relative imports
from .colors import black_color_trait, white_color_trait
from .coordinate_box import CoordinateBox
from .enable_traits import bounds_trait, coordinate_trait, LineStyle
from .interactor import Interactor

# Reusable delegate trait forwarding to an ``inner`` object; with
# modify=True, assignments are written through to the delegate (traits
# Delegate semantics).
coordinate_delegate = Delegate("inner", modify=True)

# Default sequence of drawing-layer names.
DEFAULT_DRAWING_ORDER = [
    "background",
    "underlay",
    "mainlayer",
    "border",
    "overlay",
]


class Component(CoordinateBox, Interactor):
    """
    Component is the base class for most Enable objects.  In addition to the
    basic position and container features of Component, it also supports
    Viewports and has finite bounds.

    Since Components can have a border and padding, there is an additional set
    of bounds and position attributes that define the "outer box" of the
    components. These cannot be set, since they are secondary attributes
コード例 #24
0
class SubDOTSEval(TStepperEval):
    '''
    Domain with uniform FE-time-step-eval.

    Most of the stepping interface is delegated to ``dots_integ``;
    :meth:`apply_constraints` adds the kinematic links between the refined
    subgrids of ``sdomain`` and its parent domain.
    '''
    dots_integ = Instance(ITStepperEval)

    new_cntl_var = Delegate('dots_integ')
    new_resp_var = Delegate('dots_integ')
    new_tangent_operator = Delegate('dots_integ')

    # The following operators should be run on each subdomain separately
    # the state array should be merged together from the several grids.
    state_array_size = Delegate('dots_integ')
    state_array = Delegate('dots_integ')
    ip_offset = Delegate('dots_integ')
    setup = Delegate('dots_integ')
    get_corr_pred = Delegate('dots_integ')
    map_u = Delegate('dots_integ')
    rte_dict = Delegate('dots_integ')
    get_vtk_cell_data = Delegate('dots_integ')
    get_vtk_X = Delegate('dots_integ')
    get_vtk_r_arr = Delegate('dots_integ')
    get_current_values = Delegate('dots_integ')
    get_vtk_pnt_ip_map = Delegate('dots_integ')

    sdomain = WeakRef

    # When True, apply_constraints prints its intermediate values.
    debug = Bool(False)

    def apply_constraints(self, K):
        '''Register constraints linking refined subgrids to the parent domain.

        For every ``(p, fe_domain)`` in ``self.sdomain.subgrids()``, each
        boundary dof of the subgrid is located inside parent element ``p`` by
        solving for local coordinates with ``fsolve``. The dof is then
        constrained through
        ``K.register_constraint(a=dof, alpha=N_mtx[i], ix_a=parent_dofs)``.
        Returns immediately when ``self.sdomain`` has no parent.

        :param K: object providing ``register_constraint``.
        '''
        # Take care for kinematic compatibility between the subdomains and domain
        #
        # = Purpose =
        #
        # At this stage, the spatial domain has been refined - it contains the
        # list of registered refinements. These refinements have been added
        # during the problem setup or by the adaptive strategy.
        #
        # The manipulation of the domains is done using the DOTSList interface.
        # Within this interface, new refinement levels can be added with the backward
        # reference to the original level. The refinement levels provide the skeleton
        # for spatial refinement steps that is done incrementally by specifying the cells
        # of the coarse level to be moved/refined into the finer levels. The process may
        # run recursively.
        #
        # In this setup - run the loop over the refinement steps and impose
        # kinematic constraints between the parent and child domain levels.
        #
        parent_domain = self.sdomain.parent

        # Identity test: ``==`` can be overloaded (e.g. element-wise).
        if parent_domain is None:
            return

        parent_fets_eval = parent_domain.fets_eval

        for p, fe_domain in self.sdomain.subgrids():

            dof_grid = fe_domain.dof_grid
            if self.debug:
                print('parent')
                print(p)

            # Get the X coordinates from the parent !!!
            #
            # @todo: must define the dof_grid on the FEPatchedGrid
            parent_dofs = parent_domain.fe_subgrids[0][p].dofs[0].flatten()
            parent_points = parent_domain.fe_subgrids[0][p].dof_X[0]

            if self.debug:
                print('parent_dofs')
                print(parent_dofs)
                print('parent_points')
                print(parent_points)

            # Get the geometry approximation used in the super domain
            #
            N_geo_mtx = parent_fets_eval.get_N_geo_mtx

            # start vector for the search of local coordinate
            # (``float`` is float64, the same as the old 'float_' alias that
            # NumPy 2 removed).
            #
            lcenter = zeros(parent_points.shape[1], dtype=float)

            # @todo - remove the [0] here - the N_geo_ntx should return an 1d array
            # to deliver simple coordinates instead of array of a single coordinate
            #
            def geo_approx(gpos, lpos):
                # Residual between the parent geometry mapped at local
                # coordinate ``lpos`` and the global point ``gpos``.
                return dot(N_geo_mtx(lpos)[0], parent_points) - gpos

            # @todo use get_dNr_geo_mtx as fprime parameter
            #
            # geo_dapprox = ...

            # For each element in the grid evaluate the links.
            # Get the boundary dofs of the refinement.
            # Put the value into the ...
            #
            # Separate names keep the per-point ``dofs`` from shadowing the
            # array being iterated.
            boundary_dofs, boundary_coords = dof_grid.get_boundary_dofs()
            for dofs, gpos in zip(boundary_dofs, boundary_coords):
                # find the pos within the parent domain at the position p
                #
                #
                # lpos = self.super_domain[pos].get_local_pos( pos )
                # N_mtx = self.super_domain.fets_eval.get_N_mtx( lpos )
                # K.register_constraint( a = dof, alpha = N, ix_a = super_dofs
                solution = fsolve(lambda lpos: geo_approx(gpos, lpos), lcenter)
                if isinstance(solution, float):
                    lpos = array([solution], dtype=float)
                else:
                    lpos = solution

                if self.debug:
                    print('\tp', p, '\tdofs', dofs, '\tgpos', gpos, '\tlpos',
                          lpos)

                N_mtx = parent_fets_eval.get_N_mtx(lpos)

                if self.debug:
                    print('N_mtx')
                    print(N_mtx)

                for i, dof in enumerate(dofs):
                    K.register_constraint(a=dof,
                                          alpha=N_mtx[i],
                                          ix_a=parent_dofs)
コード例 #25
0
class Plot(DataView):
    """ Represents a correlated set of data, renderers, and axes in a single
    screen region.

    A Plot can reference an arbitrary amount of data and can have an
    unlimited number of renderers on it, but it has a single X-axis and a
    single Y-axis for all of its associated data. Therefore, there is a single
    range in X and Y, although there can be many different data series. A Plot
    also has a single set of grids and a single background layer for all of its
    renderers.  It cannot be split horizontally or vertically; to do so,
    create a VPlotContainer or HPlotContainer and put the Plots inside those.
    Plots can be overlaid as well; be sure to set the **bgcolor** of the
    overlaying plots to "none" or "transparent".

    A Plot consists of composable sub-plots.  Each of these is created
    or destroyed using the plot() or delplot() methods.  Every time that
    new data is used to drive these sub-plots, it is added to the Plot's
    list of data and data sources.  Data sources are reused whenever
    possible; in order to have the same actual array drive two de-coupled
    data sources, create those data sources before handing them to the Plot.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The PlotData instance that drives this plot.
    data = Instance(AbstractPlotData)

    # Mapping of data names from self.data to their respective datasources.
    datasources = Dict(Str, Instance(AbstractDataSource))

    #------------------------------------------------------------------------
    # General plotting traits
    #------------------------------------------------------------------------

    # Mapping of plot names to *lists* of plot renderers.
    plots = Dict(Str, List)

    # The default index to use when adding new subplots.
    default_index = Instance(AbstractDataSource)

    # Optional mapper for the color axis.  Not instantiated until first use;
    # destroyed if no color plots are on the plot.
    color_mapper = Instance(AbstractColormap)

    # List of colors to cycle through when auto-coloring is requested. Picked
    # and ordered to be red-green color-blind friendly, though should not
    # be an issue for blue-yellow.
    auto_colors = List(["green", "lightgreen", "blue", "lightblue", "red",
                        "pink", "darkgray", "silver"])

    # index into auto_colors list
    _auto_color_idx = Int(-1)
    _auto_edge_color_idx = Int(-1)
    _auto_face_color_idx = Int(-1)

    # Mapping of renderer type string to renderer class
    # This can be overriden to customize what renderer type the Plot
    # will instantiate for its various plotting methods.
    renderer_map = Dict(dict(line = LinePlot,
                             bar = BarPlot,
                             scatter = ScatterPlot,
                             polygon = PolygonPlot,
                             filled_line = FilledLinePlot,
                             cmap_scatter = ColormappedScatterPlot,
                             img_plot = ImagePlot,
                             cmap_img_plot = CMapImagePlot,
                             contour_line_plot = ContourLinePlot,
                             contour_poly_plot = ContourPolyPlot,
                             candle = CandlePlot,
                             quiver = QuiverPlot,))

    #------------------------------------------------------------------------
    # Annotations and decorations
    #------------------------------------------------------------------------

    # The title of the plot.
    title = Property()

    # The font to use for the title.
    title_font = Property()

    # Convenience attribute for title.overlay_position; can be "top",
    # "bottom", "left", or "right".
    title_position = Property()

    # Use delegates to expose the other PlotLabel attributes of the plot title
    title_text = Delegate("_title", prefix="text", modify=True)
    title_color = Delegate("_title", prefix="color", modify=True)
    title_angle = Delegate("_title", prefix="angle", modify=True)

    # The PlotLabel object that contains the title.
    _title = Instance(PlotLabel)

    # The legend on the plot.
    legend = Instance(Legend)

    # Convenience attribute for legend.align; can be "ur", "ul", "ll", "lr".
    legend_alignment = Property

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, data=None, **kwtraits):
        """ Create a Plot, optionally driven by *data*.

        Parameters
        ----------
        data : AbstractPlotData, ndarray, tuple, list, or None
            The data for this plot.  An AbstractPlotData instance is used
            directly; an ndarray, tuple or list is wrapped in an
            ArrayPlotData.
        **kwtraits :
            Initial trait values.  ``origin`` is stored as
            **default_origin**, and ``title`` is applied only after the
            title label has been created.

        Raises
        ------
        ValueError
            If *data* is not None and cannot be turned into plot data.
        """
        if 'origin' in kwtraits:
            self.default_origin = kwtraits.pop('origin')
        # Hold the title back: it is applied below, once the _title
        # PlotLabel has been created.
        title = kwtraits.pop("title", None)
        super(Plot, self).__init__(**kwtraits)
        if data is not None:
            if isinstance(data, AbstractPlotData):
                self.data = data
            elif isinstance(data, (ndarray, tuple, list)):
                self.data = ArrayPlotData(data)
            else:
                raise ValueError("Don't know how to create PlotData for data "
                                 "of type " + str(type(data)))

        if not self._title:
            self._title = PlotLabel(font="swiss 16", visible=False,
                                   overlay_position="top", component=self)
        if title is not None:
            self.title = title

        if not self.legend:
            self.legend = Legend(visible=False, align="ur", error_icon="blank",
                                 padding=10, component=self)

        # ensure that we only get displayed once by new_window()
        self._plot_ui_info = None

    def add_xy_plot(self, index_name, value_name, renderer_factory, name=None,
                    origin=None, **kwds):
        """ Add a BaseXYPlot renderer subclass to this Plot.

        The index and value datasources are looked up (or created) by name
        and registered with this plot's index and value ranges.  A fresh
        mapper is built for each axis: LinearMapper when the corresponding
        scale (**index_scale** / **value_scale**) is "linear", LogMapper
        otherwise.  Unlike plot(), these mappers are created without
        copying the existing mappers' stretch_data setting.

        Parameters
        ----------
        index_name : str
            The name of the index datasource.
        value_name : str
            The name of the value datasource.
        renderer_factory : callable
            Called with index, value, index_mapper, value_mapper,
            orientation, origin and *kwds* to create the renderer.
        name : string (optional)
            The name of the plot.  If None, then a default one is created
            (usually "plotNNN").
        origin : string (optional)
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right".
            Defaults to **default_origin**.
        **kwds :
            Additional keywords to pass to the factory.

        Returns
        -------
        The one-element list holding the new renderer, as stored in
        ``self.plots[name]``.
        """
        plot_name = name if name is not None else self._make_new_plot_name()
        plot_origin = origin if origin is not None else self.default_origin

        index_ds = self._get_or_create_datasource(index_name)
        self.index_range.add(index_ds)
        value_ds = self._get_or_create_datasource(value_name)
        self.value_range.add(value_ds)

        # "linear" selects LinearMapper; any other scale selects LogMapper.
        index_mapper_cls = LinearMapper if self.index_scale == "linear" else LogMapper
        value_mapper_cls = LinearMapper if self.value_scale == "linear" else LogMapper

        renderer = renderer_factory(index=index_ds,
                                    value=value_ds,
                                    index_mapper=index_mapper_cls(range=self.index_range),
                                    value_mapper=value_mapper_cls(range=self.value_range),
                                    orientation=self.orientation,
                                    origin=plot_origin,
                                    **kwds)
        self.add(renderer)
        self.plots[plot_name] = [renderer]
        self.invalidate_and_redraw()
        return self.plots[plot_name]

    def plot(self, data, type="line", name=None, index_scale="linear",
             value_scale="linear", origin=None, **styles):
        """ Adds a new sub-plot using the given data and plot style.

        Parameters
        ----------
        data : string, tuple(string), list(string)
            The names of the data to be plotted. The type of plot and the
            number of names determines how they are interpreted:

            one item: (line/scatter/polygon/bar/filled_line)
                The data is treated as the value and self.default_index is
                used as the index.  If **default_index** does not exist, one is
                created from arange(len(*data*))
            two or more items: (line/scatter/polygon/bar/filled_line)
                Interpreted as (index, value1, value2, ...).  Each index,value
                pair forms a new plot of the type specified.
            exactly three items: (cmap_scatter)
                Interpreted as (index, value, color_values).  The
                ``color_mapper`` style keyword is required: either an
                AbstractColormap instance or a callable that takes a
                DataRange1D and returns one.

        type : string
            The type of plot to add: "line", "scatter", "polygon", "bar",
            "filled_line" or "cmap_scatter".  Only one type per call is
            supported.
        name : string
            The name of the plot.  If None, then a default one is created
            (usually "plotNNN").
        index_scale : string
            The type of scale to use for the index axis. If not "linear", then
            a log scale is used.
        value_scale : string
            The type of scale to use for the value axis. If not "linear", then
            a log scale is used.
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right".
            Defaults to **default_origin**.
        styles : series of keyword arguments
            attributes and values that apply to one or more of the
            plot types requested, e.g.,'line_color' or 'line_width'.
            A value of "auto" for ``color`` (line, scatter, bar) or for
            ``edge_color``/``face_color`` (polygon, filled_line) picks the
            next entry of **auto_colors** for each value series.

        Examples
        --------
        ::

            plot("my_data", type="line", name="myplot", color="lightblue")

            plot(("x-data", "y-data"), type="scatter")

            plot(("x", "y1", "y2", "y3"))

        Returns
        -------
        [renderers] -> list of renderers created in response to this call to
        plot(), or None if *data* is empty.

        Raises
        ------
        ValueError
            If *type* is not a supported plot type, or if cmap_scatter data
            is not exactly three names long or lacks a ``color_mapper``.
        """
        if len(data) == 0:
            return

        if isinstance(data, basestring):
            data = (data,)

        self.index_scale = index_scale
        self.value_scale = value_scale

        # TODO: support lists of plot types
        plot_type = type
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        def make_mappers():
            # Fresh (index, value) mappers for one renderer, following the
            # current scales and keeping the existing mappers' stretch_data.
            if self.index_scale == "linear":
                imap = LinearMapper(range=self.index_range,
                                    stretch_data=self.index_mapper.stretch_data)
            else:
                imap = LogMapper(range=self.index_range,
                                 stretch_data=self.index_mapper.stretch_data)
            if self.value_scale == "linear":
                vmap = LinearMapper(range=self.value_range,
                                    stretch_data=self.value_mapper.stretch_data)
            else:
                vmap = LogMapper(range=self.value_range,
                                 stretch_data=self.value_mapper.stretch_data)
            return imap, vmap

        def next_auto_color(counter_name):
            # Advance the named auto-colour counter trait and return the
            # colour it now points at.
            idx = (getattr(self, counter_name) + 1) % len(self.auto_colors)
            setattr(self, counter_name, idx)
            return self.auto_colors[idx]

        if plot_type in ("line", "scatter", "polygon", "bar", "filled_line"):
            # Tie data to the index range
            if len(data) == 1:
                if self.default_index is None:
                    # Create the default index based on the length of the first
                    # data series
                    value = self._get_or_create_datasource(data[0])
                    self.default_index = ArrayDataSource(arange(len(value.get_data())),
                                                         sort_order="none")
                    self.index_range.add(self.default_index)
                index = self.default_index
            else:
                index = self._get_or_create_datasource(data[0])
                if self.default_index is None:
                    self.default_index = index
                self.index_range.add(index)
                data = data[1:]

            # Tie data to the value_range and create the renderer for each data
            cls = self.renderer_map[plot_type]
            new_plots = []
            for value_name in data:
                value = self._get_or_create_datasource(value_name)
                self.value_range.add(value)

                # Resolve "auto" colours on a per-series copy of the styles so
                # each value series gets the next colour in the cycle; writing
                # the resolved colour back into ``styles`` would make every
                # later series reuse the first series' colour.
                series_styles = dict(styles)
                if plot_type in ("line", "scatter"):
                    if styles.get("color") == "auto":
                        series_styles["color"] = next_auto_color("_auto_color_idx")
                elif plot_type in ("polygon", "filled_line"):
                    if styles.get("edge_color") == "auto":
                        series_styles["edge_color"] = \
                            next_auto_color("_auto_edge_color_idx")
                    if styles.get("face_color") == "auto":
                        series_styles["face_color"] = \
                            next_auto_color("_auto_face_color_idx")
                elif plot_type == "bar":
                    # Bar plots take their auto colour through fill_color.
                    if styles.get("color") == "auto":
                        series_styles["fill_color"] = \
                            next_auto_color("_auto_color_idx")

                imap, vmap = make_mappers()
                renderer = cls(index=index,
                               value=value,
                               index_mapper=imap,
                               value_mapper=vmap,
                               orientation=self.orientation,
                               origin=origin,
                               **series_styles)

                self.add(renderer)
                new_plots.append(renderer)

            if plot_type == 'bar':
                # For bar plots, compute the ranges from the data to make the
                # plot look clean.
                if self.index_range.bounds_func is None:
                    # Resolve the bar width once, instead of instantiating a
                    # throwaway renderer every time the range recomputes.
                    if 'bar_width' in styles:
                        bar_width = styles['bar_width']
                    else:
                        bar_width = cls().bar_width

                    def custom_index_func(data_low, data_high, margin,
                                          tight_bounds):
                        """ Pad the index bounds (in data space) by one bar
                        width on each side.
                        """
                        return data_low - bar_width, data_high + bar_width

                    self.index_range.bounds_func = custom_index_func

                if self.value_range.bounds_func is None:
                    def custom_value_func(data_low, data_high, margin,
                                          tight_bounds):
                        """ Pad the value bounds (in data space) by 10% of the
                        data span on each side.
                        """
                        pad = (data_high - data_low) * 0.1
                        return data_low - pad, data_high + pad

                    self.value_range.bounds_func = custom_value_func

                self.index_range.tight_bounds = False
                self.value_range.tight_bounds = False
                self.index_range.refresh()
                self.value_range.refresh()

            self.plots[name] = new_plots

        elif plot_type == "cmap_scatter":
            if len(data) != 3:
                raise ValueError("Colormapped scatter plots require (index, value, color) data")
            # Validate before touching any datasource or range so that a
            # rejected call leaves the plot unchanged.
            if "color_mapper" not in styles:
                raise ValueError("Scalar 2D data requires a color_mapper.")

            index = self._get_or_create_datasource(data[0])
            if self.default_index is None:
                self.default_index = index
            self.index_range.add(index)
            value = self._get_or_create_datasource(data[1])
            self.value_range.add(value)
            color = self._get_or_create_datasource(data[2])

            colormap = styles.pop("color_mapper")

            # Reuse the existing colour range, if any, so several colormapped
            # plots share one colour scale.
            if self.color_mapper is not None and self.color_mapper.range is not None:
                color_range = self.color_mapper.range
            else:
                color_range = DataRange1D()

            if isinstance(colormap, AbstractColormap):
                self.color_mapper = colormap
                if colormap.range is None:
                    color_range.add(color)
                    colormap.range = color_range
            elif callable(colormap):
                # A colormap factory: build a mapper over the colour range.
                color_range.add(color)
                self.color_mapper = colormap(color_range)
            else:
                raise ValueError("Unexpected colormap %r in plot()." % (colormap,))

            imap, vmap = make_mappers()
            cls = self.renderer_map["cmap_scatter"]
            renderer = cls(index=index,
                           index_mapper=imap,
                           value=value,
                           value_mapper=vmap,
                           color_data=color,
                           color_mapper=self.color_mapper,
                           orientation=self.orientation,
                           origin=origin,
                           **styles)
            self.add(renderer)

            self.plots[name] = [renderer]
        else:
            raise ValueError("Unknown plot type: " + plot_type)

        return self.plots[name]


    def img_plot(self, data, name=None, colormap=None,
                 xbounds=None, ybounds=None, origin=None, hide_grids=True,
                 **styles):
        """ Adds an image plot to this Plot object.

        If *data* has shape (N, M, 3) or (N, M, 4), it is drawn as RGB or
        RGBA (respectively) and *colormap* is ignored.  Otherwise the values
        are colour-mapped: by *colormap* if given, else by this plot's
        existing **color_mapper**, else by a new 'Spectral' colormap over
        the data.  Whichever colormap is used becomes this plot's
        **color_mapper**.

        *Data* should be in row-major order, so that xbounds corresponds to
        *data*'s second axis, and ybounds corresponds to the first axis.

        Parameters
        ----------
        data : string
            The name of the data array in this plot's data (self.data).
        name : string
            The name of the plot; if omitted, then a name is generated.
        colormap : AbstractColormap or callable, optional
            A colormap instance (given a DataRange1D over *data* if it has
            no range yet), or a callable taking a DataRange1D and returning
            a colormap.  Ignored for RGB/RGBA data.
        xbounds, ybounds : string, tuple, or ndarray
            Bounds where this image resides. Bound may be: a) names of
            data in the plot data; b) tuples of (low, high) in data space,
            c) 1D arrays of values representing the pixel boundaries (must
            be 1 element larger than underlying data), or
            d) 2D arrays as obtained from a meshgrid operation
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
        hide_grids : bool, default True
            Whether or not to automatically hide the grid lines on the plot
        styles : series of keyword arguments
            Attributes and values passed on to the image renderer.

        Returns
        -------
        A one-element list holding the new renderer.

        Raises
        ------
        ValueError
            If *data* is 3D but its last axis is neither 3 nor 4 long.
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        value_ds = self._get_or_create_datasource(data)
        image = value_ds.get_data()

        if len(image.shape) == 3:
            # True-colour image: colours come straight from the data.
            if image.shape[2] not in (3, 4):
                raise ValueError("Image plots require color depth of 3 or 4.")
            return self._create_2d_plot(self.renderer_map["img_plot"], name,
                                        origin, xbounds, ybounds, value_ds,
                                        hide_grids, **styles)

        # Scalar image: choose (or build) the colormap to apply.
        if colormap is None:
            colormap = self.color_mapper
            if colormap is None:
                colormap = Spectral(DataRange1D(value_ds))
        elif isinstance(colormap, AbstractColormap):
            if colormap.range is None:
                colormap.range = DataRange1D(value_ds)
        else:
            # Anything else is treated as a colormap factory.
            colormap = colormap(DataRange1D(value_ds))
        self.color_mapper = colormap

        return self._create_2d_plot(self.renderer_map["cmap_img_plot"], name,
                                    origin, xbounds, ybounds, value_ds,
                                    hide_grids, value_mapper=colormap, **styles)


    def contour_plot(self, data, type="line", name=None, poly_cmap=None,
                     xbounds=None, ybounds=None, origin=None, hide_grids=True,
                     **styles):
        """ Adds a contour plot of a 2D scalar field to this Plot object.

        Parameters
        ----------
        data : string
            The name of the data array in this plot's data (self.data),
            which must be floating point data.
        type : string
            The kind of contour plot: "line" or "poly".  For "poly", when
            no colormap is provided via *poly_cmap*, a default 'Spectral'
            colormap is used.
        name : string
            The name of the plot; if omitted, then a name is generated.
        poly_cmap : AbstractColormap or function, optional
            Colormap for "poly" contour plots (ignored for "line").  A plain
            function, such as the colormap factories in
            chaco.default_colormaps, is called with a DataRange1D over
            *data*; a colormap object whose ``range`` is None is given one.
        xbounds, ybounds : string, tuple, or ndarray
            Bounds where this image resides. Bound may be: a) names of
            data in the plot data; b) tuples of (low, high) in data space,
            c) 1D arrays of values representing the pixel boundaries (must
            be 1 element larger than underlying data), or
            d) 2D arrays as obtained from a meshgrid operation
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
        hide_grids : bool, default True
            Whether or not to automatically hide the grid lines on the plot
        styles : series of keyword arguments
            Attributes and values passed on to the contour renderer.  For
            "line" plots, a ``colors`` entry is resolved like *poly_cmap*.

        Returns
        -------
        A one-element list holding the new renderer.

        Raises
        ------
        ValueError
            If the data is not a scalar field (value_depth != 1) or *type*
            is neither "line" nor "poly".
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        value_ds = self._get_or_create_datasource(data)
        if value_ds.value_depth != 1:
            raise ValueError("Contour plots require 2D scalar field")

        def resolve_colormap(cmap):
            # Plain functions are colormap factories: build one over the data.
            if isinstance(cmap, FunctionType):
                return cmap(DataRange1D(value_ds))
            # Colormap objects without a range get one spanning the data;
            # values with no ``range`` attribute pass through untouched.
            if getattr(cmap, 'range', 'dummy') is None:
                cmap.range = DataRange1D(value_ds)
            return cmap

        if type == "line":
            renderer_cls = self.renderer_map["contour_line_plot"]
            renderer_kwargs = dict(styles)
            if "colors" in renderer_kwargs:
                renderer_kwargs["colors"] = resolve_colormap(renderer_kwargs["colors"])
        elif type == "poly":
            if poly_cmap is None:
                poly_cmap = Spectral(DataRange1D(value_ds))
            else:
                poly_cmap = resolve_colormap(poly_cmap)
            renderer_cls = self.renderer_map["contour_poly_plot"]
            renderer_kwargs = dict(color_mapper=poly_cmap, **styles)
        else:
            raise ValueError("Unhandled contour plot type: " + type)

        return self._create_2d_plot(renderer_cls, name, origin, xbounds,
                                    ybounds, value_ds, hide_grids,
                                    **renderer_kwargs)


    def _process_2d_bounds(self, bounds, array_data, axis):
        """Convert a user bounds specification into pixel-edge coordinates.

        The result always has ``array_data.shape[axis] + 1`` evenly spaced
        entries, one per pixel boundary along *axis*.

        Parameters
        ----------
        bounds : None, tuple, or ndarray
            None gives the integer edges 0..n.  A tuple is read as
            (low, high).  A 1D array must already hold n + 1 edges; only its
            first and last entries are used.  A 2D array must have the same
            shape as *array_data* (e.g. from meshgrid); its coordinates
            along *axis* are taken as sample positions, and the last edge is
            placed one sample step past the last sample.
        array_data : 2D array
            The 2D plot data
        axis : int
            The axis of *array_data* along which the bounds are to be set.

        Raises
        ------
        ValueError
            For an array of the wrong size or shape, or an unsupported type.
        """
        n_edges = array_data.shape[axis] + 1

        if bounds is None:
            return arange(n_edges)

        kind = type(bounds)
        if kind is tuple:
            # (low, high) limits, subdivided evenly.
            return linspace(bounds[0], bounds[1], n_edges)

        if kind is not ndarray or len(bounds.shape) not in (1, 2):
            raise ValueError("bounds must be None, a tuple, an array, "
                             "or a PlotData name")

        if len(bounds.shape) == 1:
            # Explicit pixel edges: must already be one longer than the data.
            if len(bounds) != n_edges:
                raise ValueError("1D bounds of an image plot needs to have 1 more "
                                 "element than its corresponding data shape, because "
                                 "they represent the locations of pixel boundaries.")
            return linspace(bounds[0], bounds[-1], n_edges)

        # 2D coordinates, as produced by e.g.
        # >>> xbounds, ybounds = meshgrid(...)
        # >>> z = f(xbounds, ybounds)
        if bounds.shape != array_data.shape:
            raise ValueError("2D bounds of an image plot needs to have the same "
                             "shape as the underlying data, because "
                             "they are assumed to be generated from meshgrids.")
        samples = bounds[:, 0] if axis == 0 else bounds[0, :]
        step = samples[1] - samples[0]
        return linspace(samples[0], samples[-1] + step, n_edges)


    def _create_2d_plot(self, cls, name, origin, xbounds, ybounds, value_ds,
                        hide_grids, **kwargs):
        """ Build a renderer of class *cls* over a grid index and add it.

        *xbounds* / *ybounds* may name arrays in this plot's data; they are
        resolved first and then normalised to pixel edges along axis 1 (x)
        and axis 0 (y) of the value data respectively.  The resulting
        GridDataSource is registered with **range2d**.  When *hide_grids* is
        true, the x and y grids are made invisible.

        Returns
        -------
        The one-element list holding the new renderer, as stored in
        ``self.plots[name]``.
        """
        plot_name = name if name is not None else self._make_new_plot_name()
        plot_origin = origin if origin is not None else self.default_origin

        grid_values = value_ds.get_data()

        def edges_for(bounds, axis):
            # A string names a data array; fetch it before normalising.
            if isinstance(bounds, basestring):
                bounds = self._get_or_create_datasource(bounds).get_data()
            return self._process_2d_bounds(bounds, grid_values, axis)

        xs = edges_for(xbounds, 1)
        ys = edges_for(ybounds, 0)

        grid_index = GridDataSource(xs, ys, sort_order=('ascending', 'ascending'))
        self.range2d.add(grid_index)
        grid_mapper = GridMapper(range=self.range2d,
                                 stretch_data_x=self.x_mapper.stretch_data,
                                 stretch_data_y=self.y_mapper.stretch_data)

        renderer = cls(index=grid_index,
                       value=value_ds,
                       index_mapper=grid_mapper,
                       orientation=self.orientation,
                       origin=plot_origin,
                       **kwargs)

        if hide_grids:
            self.x_grid.visible = False
            self.y_grid.visible = False

        self.add(renderer)
        self.plots[plot_name] = [renderer]
        return self.plots[plot_name]


    def candle_plot(self, data, name=None, value_scale="linear", origin=None,
                    **styles):
        """ Adds a new candle sub-plot using the given data and plot style.

        Parameters
        ----------
        data : list(string), tuple(string)
            The names of the data to be plotted in the ArrayDataSource.  The
            number of arguments determines how they are interpreted:

            (index, bar_min, bar_max)
                filled or outline-only bar extending from **bar_min** to
                **bar_max**

            (index, bar_min, center, bar_max)
                above, plus a center line of a different color at **center**

            (index, min, bar_min, bar_max, max)
                bar extending from **bar_min** to **bar_max**, with thin
                bars at **min** and **max** connected to the bar by a long
                stem

            (index, min, bar_min, center, bar_max, max)
                like above, plus a center line of a different color and
                configurable thickness at **center**

        name : string
            The name of the plot.  If None, then a default one is created.

        value_scale : string
            The type of scale to use for the value axis.  If not "linear",
            then a log scale is used.

        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right".
            Defaults to **default_origin**.

        Styles
        ------
        These are all optional keyword arguments.

        bar_color : string, 3- or 4-tuple
            The fill color of the bar; defaults to "auto".
        bar_line_color : string, 3- or 4-tuple
            The color of the rectangular box forming the bar.
        stem_color : string, 3- or 4-tuple (default = bar_line_color)
            The color of the stems reaching from the bar to the min and
            max values.
        center_color : string, 3- or 4-tuple (default = bar_line_color)
            The color of the line drawn across the bar at the center values.
        line_width : int (default = 1)
            The thickness, in pixels, of the outline around the bar.
        stem_width : int (default = line_width)
            The thickness, in pixels, of the stem lines
        center_width : int (default = line_width)
            The width, in pixels, of the line drawn across the bar at the
            center values.
        end_cap : bool (default = True)
            Whether or not to draw bars at the min and max extents of the
            error bar.

        Returns
        -------
        [renderers] -> list of renderers created in response to this call,
        or None if *data* is empty.

        Raises
        ------
        ValueError
            If *data* does not contain between 3 and 6 names.
        """
        if len(data) == 0:
            return
        if not 3 <= len(data) <= 6:
            # Only the four layouts documented above are supported; reject
            # anything else before any plot state is modified.
            raise ValueError("Candle plots require 3 to 6 data names, got %d"
                             % len(data))
        self.value_scale = value_scale

        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        # Create the datasources.  Everything after the index is a value
        # series, in the order the layouts above list them.
        sources = [self._get_or_create_datasource(d) for d in data]
        index, values = sources[0], sources[1:]
        min_values = center = max_values = None
        if len(data) == 3:
            bar_min, bar_max = values
        elif len(data) == 4:
            bar_min, center, bar_max = values
        elif len(data) == 5:
            min_values, bar_min, bar_max, max_values = values
        else:
            min_values, bar_min, center, bar_max, max_values = values
        self.value_range.add(*values)
        self.index_range.add(index)

        # NOTE(review): the resolved colour goes into ``color``, while a
        # ``bar_color="auto"`` entry (if given) is still passed through to the
        # renderer unchanged — confirm CandlePlot accepts that.
        if styles.get("bar_color") == "auto" or styles.get("color") == "auto":
            self._auto_color_idx = \
                (self._auto_color_idx + 1) % len(self.auto_colors)
            styles["color"] = self.auto_colors[self._auto_color_idx]

        if self.index_scale == "linear":
            imap = LinearMapper(range=self.index_range,
                                stretch_data=self.index_mapper.stretch_data)
        else:
            imap = LogMapper(range=self.index_range,
                             stretch_data=self.index_mapper.stretch_data)
        if self.value_scale == "linear":
            vmap = LinearMapper(range=self.value_range,
                                stretch_data=self.value_mapper.stretch_data)
        else:
            vmap = LogMapper(range=self.value_range,
                             stretch_data=self.value_mapper.stretch_data)

        cls = self.renderer_map["candle"]
        plot = cls(index=index,
                   min_values=min_values,
                   bar_min=bar_min,
                   center_values=center,
                   bar_max=bar_max,
                   max_values=max_values,
                   index_mapper=imap,
                   value_mapper=vmap,
                   orientation=self.orientation,
                   # Use the resolved origin argument (falling back to
                   # default_origin), as the other plot methods do.
                   origin=origin,
                   **styles)
        self.add(plot)
        self.plots[name] = [plot]
        return [plot]

    def quiverplot(self, data, name=None, origin=None,
                    **styles):
        """ Adds a new quiver (arrow field) sub-plot built from named data.

        Parameters
        ----------
        data : list(string), tuple(string)
            The names of the data to be plotted in the ArrayDataSource.  Only
            one combination is accepted by this function:

            (index, value, vectors)
                index and value together give the start coordinates of each
                arrow; vectors names the per-point vector data (documented as
                an N x 2 array, one row per start point).

        name : string
            The name of the plot.  If None, then a default one is created.

        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
            If None, ``self.default_origin`` is used.

        Styles
        ------
        These are all optional keyword arguments.

        line_color : string (default = "black")
            The color of the arrows
        line_width : float (default = 1.0)
            The thickness, in pixels, of the arrows.
        arrow_size : int (default = 5)
            The length, in pixels, of the arrowhead

        Returns
        -------
        [renderers] -> list of renderers created in response to this call.

        Notes
        -----
        Both axes always get a LinearMapper, whatever ``index_scale`` and
        ``value_scale`` are set to.
        """
        plot_name = self._make_new_plot_name() if name is None else name
        plot_origin = self.default_origin if origin is None else origin

        index_ds, value_ds, vectors_ds = map(self._get_or_create_datasource,
                                             data)
        self.index_range.add(index_ds)
        self.value_range.add(value_ds)

        index_map = LinearMapper(range=self.index_range,
                                 stretch_data=self.index_mapper.stretch_data)
        value_map = LinearMapper(range=self.value_range,
                                 stretch_data=self.value_mapper.stretch_data)

        renderer_cls = self.renderer_map["quiver"]
        renderer = renderer_cls(index=index_ds,
                                value=value_ds,
                                vectors=vectors_ds,
                                index_mapper=index_map,
                                value_mapper=value_map,
                                name=plot_name,
                                origin=plot_origin,
                                **styles)
        self.add(renderer)
        self.plots[plot_name] = [renderer]
        return [renderer]

    def delplot(self, *names):
        """ Removes the named sub-plots.

        Every renderer registered under one of ``names`` is removed from this
        container and the names are dropped from ``self.plots``.  Any index
        or value datasource that a removed renderer used, and that no
        remaining plot still uses, is also removed from the data ranges.

        All names are popped from ``self.plots`` before any renderer is
        removed, so an unknown name raises KeyError without detaching any
        renderer from the container.
        """
        doomed = [self.plots.pop(plot_name) for plot_name in names]

        candidates = set()
        for renderer in itertools.chain.from_iterable(doomed):
            self.remove(renderer)
            candidates.update((renderer.index, renderer.value))

        # Sources still referenced by a surviving plot must stay in the ranges.
        still_used = set()
        for renderer in itertools.chain.from_iterable(self.plots.values()):
            still_used.update((renderer.index, renderer.value))

        for source in candidates - still_used - {None}:
            dimension = source.index_dimension
            if dimension == "scalar":
                # The source may live in either 1D range; removing it from
                # both does no harm.
                self.index_range.remove(source)
                self.value_range.remove(source)
            elif dimension == "image":
                self.range2d.remove(source)
            else:
                warnings.warn("Couldn't remove datasource from datarange.")

    def hideplot(self, *names):
        """ Makes the renderers of the named plots invisible.

        The renderers stay in the container and in ``self.plots``.  All
        names are looked up before anything changes, so an unknown name
        raises KeyError without hiding any renderer.
        """
        renderer_lists = [self.plots[plot_name] for plot_name in names]
        for renderers in renderer_lists:
            for renderer in renderers:
                renderer.visible = False

    def showplot(self, *names):
        """ Makes the renderers of the named plots visible again.

        All names are looked up before anything changes, so an unknown name
        raises KeyError without showing any renderer.
        """
        renderer_lists = [self.plots[plot_name] for plot_name in names]
        for renderers in renderer_lists:
            for renderer in renderers:
                renderer.visible = True

    def new_window(self, configure=False):
        """ Opens a PlotWindow containing this plot and returns its result.

        With ``configure=True`` the window is shown via ``configure_traits()``,
        otherwise via ``edit_traits()``.  The result is stored in
        ``self._plot_ui_info``; while that is not None, later calls return it
        without opening another window.

        Don't call this if the plot is already displayed in a window.
        """
        from chaco.ui.plot_window import PlotWindow

        if self._plot_ui_info is not None:
            return self._plot_ui_info

        window = PlotWindow(plot=self)
        show = window.configure_traits if configure else window.edit_traits
        self._plot_ui_info = show()
        return self._plot_ui_info

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------



    def _make_new_plot_name(self):
        """ Returns a string that is not already used as a plot title.
        """
        n = len(self.plots)
        plot_template = "plot%d"
        while 1:
            name = plot_template % n
            if name not in self.plots:
                break
            else:
                n += 1
        return name

    def _get_or_create_datasource(self, name):
        """ Returns the data source associated with the given name, or creates
        it if it doesn't exist.
        """

        if name not in self.datasources:
            data = self.data.get_data(name)

            if type(data) in (list, tuple):
                data = array(data)

            if isinstance(data, ndarray):
                if len(data.shape) == 1:
                    ds = ArrayDataSource(data, sort_order="none")
                elif len(data.shape) == 2:
                    ds = ImageData(data=data, value_depth=1)
                elif len(data.shape) == 3 and data.shape[2] in (3,4):
                    ds = ImageData(data=data, value_depth=int(data.shape[2]))
                else:
                    raise ValueError("Unhandled array shape in creating new "
                                     "plot: %s" % str(data.shape))
            elif isinstance(data, AbstractDataSource):
                ds = data
            else:
                raise ValueError("Couldn't create datasource for data of "
                                 "type %s" % type(data))

            self.datasources[name] = ds

        return self.datasources[name]

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _color_mapper_changed(self):
        for plist in self.plots.values():
            for plot in plist:
                plot.color_mapper = self.color_mapper
        self.invalidate_draw()

    def _data_changed(self, old, new):
        if old:
            old.on_trait_change(self._data_update_handler, "data_changed",
                                remove=True)
        if new:
            new.on_trait_change(self._data_update_handler, "data_changed")

    def _data_update_handler(self, name, event):
        # event should be a dict with keys "added", "removed", and "changed",
        # per the comments in AbstractPlotData.
        if "removed" in event:
            for name in event["removed"]:
                del self.datasources[name]

        if "added" in event:
            for name in event["added"]:
                self._get_or_create_datasource(name)

        if "changed" in event:
            for name in event["changed"]:
                if name in self.datasources:
                    source = self.datasources[name]
                    source.set_data(self.data.get_data(name))

    def _plots_items_changed(self, event):
        if self.legend:
            self.legend.plots = self.plots

    def _index_scale_changed(self, old, new):
        if old is None: return
        if new == old: return
        if not self.range2d: return
        if self.index_scale == "linear":
            imap = LinearMapper(range=self.index_range,
                                screen_bounds=self.index_mapper.screen_bounds,
                                stretch_data=self.index_mapper.stretch_data)
        else:
            imap = LogMapper(range=self.index_range,
                             screen_bounds=self.index_mapper.screen_bounds,
                             stretch_data=self.index_mapper.stretch_data)
        self.index_mapper = imap
        for key in self.plots:
            for plot in self.plots[key]:
                if not isinstance(plot, BaseXYPlot):
                    raise ValueError("log scale only supported on XY plots")
                if self.index_scale == "linear":
                    imap = LinearMapper(range=plot.index_range,
                                screen_bounds=plot.index_mapper.screen_bounds,
                                stretch_data=self.index_mapper.stretch_data)
                else:
                    imap = LogMapper(range=plot.index_range,
                                screen_bounds=plot.index_mapper.screen_bounds,
                                stretch_data=self.index_mapper.stretch_data)
                plot.index_mapper = imap

    def _value_scale_changed(self, old, new):
        if old is None: return
        if new == old: return
        if not self.range2d: return
        if self.value_scale == "linear":
            vmap = LinearMapper(range=self.value_range,
                                screen_bounds=self.value_mapper.screen_bounds,
                                stretch_data=self.value_mapper.stretch_data)
        else:
            vmap = LogMapper(range=self.value_range,
                             screen_bounds=self.value_mapper.screen_bounds,
                                stretch_data=self.value_mapper.stretch_data)
        self.value_mapper = vmap
        for key in self.plots:
            for plot in self.plots[key]:
                if not isinstance(plot, BaseXYPlot):
                    raise ValueError("log scale only supported on XY plots")
                if self.value_scale == "linear":
                    vmap = LinearMapper(range=plot.value_range,
                                screen_bounds=plot.value_mapper.screen_bounds,
                                stretch_data=self.value_mapper.stretch_data)
                else:
                    vmap = LogMapper(range=plot.value_range,
                                screen_bounds=plot.value_mapper.screen_bounds,
                                stretch_data=self.value_mapper.stretch_data)
                plot.value_mapper = vmap

    def __title_changed(self, old, new):
        """ Change handler for the private ``_title`` overlay trait.

        Delegates the overlay swap to ``_overlay_change_helper``.
        NOTE(review): Python name-mangles this method name; this presumably
        relies on the traits metaclass resolving mangled '__<trait>_changed'
        handlers -- confirm.
        """
        self._overlay_change_helper(old, new)

    def _legend_changed(self, old, new):
        self._overlay_change_helper(old, new)
        if new:
            new.plots = self.plots

    def _handle_range_changed(self, name, old, new):
        """ Overrides the DataView default behavior.

        Primarily changes how the list of renderers is looked up.
        """
        mapper = getattr(self, name+"_mapper")
        if mapper.range == old:
            mapper.range = new
        if old is not None:
            for datasource in old.sources[:]:
                old.remove(datasource)
                if new is not None:
                    new.add(datasource)
        range_name = name + "_range"
        for renderer in itertools.chain(*self.plots.values()):
            if hasattr(renderer, range_name):
                setattr(renderer, range_name, new)

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _set_legend_alignment(self, align):
        if self.legend:
            self.legend.align = align

    def _get_legend_alignment(self):
        if self.legend:
            return self.legend.align
        else:
            return None

    def _set_title(self, text):
        self._title.text = text
        if text.strip() != "":
            self._title.visible = True
        else:
            self._title.visible = False

    def _get_title(self):
        """ Property getter: the text of the title overlay. """
        return self._title.text

    def _set_title_position(self, pos):
        if self._title is not None:
            self._title.overlay_position = pos

    def _get_title_position(self):
        if self._title is not None:
            return self._title.overlay_position
        else:
            return None

    def _set_title_font(self, font):
        old_font = self._title.font
        self._title.font = font
        self.trait_property_changed("title_font", old_font, font)

    def _get_title_font(self):
        """ Property getter: the font of the title overlay. """
        return self._title.font
# Code example #26
0
# Group layout trait: one of 'normal' (the default), 'split', 'tabbed',
# 'flow' or 'fold'.
# NOTE(review): TraitPrefixList presumably also accepts an unambiguous prefix
# of one of these names (e.g. 'tab') -- confirm against the traits docs.
Layout = Trait('normal',
               TraitPrefixList('normal', 'split', 'tabbed', 'flow', 'fold'))

# Trait for the default object being edited (an Expression trait whose
# default expression is 'object'):
AnObject = Expression('object')

# The default dock style to use; also exported as 'dock_style_trait'.
# As an Enum its default is the first value, 'fixed'.
DockStyle = dock_style_trait = Enum('fixed', 'horizontal', 'vertical', 'tab',
                                    desc="the default docking style to use")

# The category of elements dragged out of the view:
ExportType = Str(desc='the category of elements dragged out of the view')

# Delegate a trait value to the object's **container** trait (also exported
# as 'container_delegate').
# NOTE(review): listenable=False presumably means changes on the container
# are not re-broadcast as notifications on the delegating object -- confirm.
ContainerDelegate = container_delegate = Delegate('container',
                                                  listenable=False)

# An identifier for the external help context (also 'help_id_trait'):
HelpId = help_id_trait = Str(desc="the external help context identifier")

# A button to add to a view.  Currently accepts any value; the stricter
# definition is kept below for reference.
AButton = Any
#AButton = Trait( '', Str, Instance( 'traitsui.menu.Action' ) )

# The set of buttons to add to the view:
Buttons = List(AButton,
               desc='the action buttons to add to the bottom of the view')

# View trait specified by name or instance.  Currently accepts any value;
# the stricter definition is kept below for reference.
AView = Any
#AView = Trait( '', Str, Instance( 'traitsui.view.View' ) )
# Code example #27
0
from .view_element import ViewSubElement

from .item import Item

from .include import Include

from .ui_traits import SequenceTypes, ContainerDelegate, Orientation, Layout

from .dock_window_theme import dock_window_theme, DockWindowTheme

#-------------------------------------------------------------------------
#  Trait definitions:
#-------------------------------------------------------------------------

# Delegate trait to the object being "shadowed": the value is read from the
# same-named trait of the object held in this object's 'shadow' trait.
ShadowDelegate = Delegate('shadow')

# Amount of padding to add around each item: an integer range from 0 to 15.
# NOTE(review): units are presumably pixels -- confirm with the toolkit code.
Padding = Range(0, 15, desc='amount of padding to add around each item')

#-------------------------------------------------------------------------
#  'Group' class:
#-------------------------------------------------------------------------


class Group(ViewSubElement):
    """ Represents a grouping of items in a user interface view.
    """

    #-------------------------------------------------------------------------
    # Trait definitions:
# Code example #28
0
class WorkbenchWindow(ApplicationWindow):
    """ A workbench window. """

    # 'IWorkbenchWindow' interface -----------------------------------------

    # The view or editor that currently has the focus.
    active_part = Instance(IWorkbenchPart)

    # The editor manager is used to create/restore editors (defaults to an
    # EditorManager bound to this window; see _editor_manager_default).
    editor_manager = Instance(IEditorManager)

    # The current selection within the window.
    selection = List()

    # The workbench that the window belongs to.
    workbench = Instance("pyface.workbench.api.IWorkbench")

    # Editors -----------------------

    # The active editor.
    active_editor = Instance(IEditor)

    # The visible (open) editors.
    editors = List(IEditor)

    # The Id of the editor area.
    editor_area_id = Constant("pyface.workbench.editors")

    # The (initial) size of the editor area (the user is free to resize it of
    # course).
    editor_area_size = Tuple((100, 100))

    # The four editor lifecycle events below are not defined on the window
    # itself: each is a Delegate("layout"), i.e. it forwards to the trait of
    # the same name on the 'layout' object (noted as an Event(IEditor)).

    # Fired when an editor is about to be opened (or restored).
    editor_opening = Delegate("layout")  # Event(IEditor)

    # Fired when an editor has been opened (or restored).
    editor_opened = Delegate("layout")  # Event(IEditor)

    # Fired when an editor is about to be closed.
    editor_closing = Delegate("layout")  # Event(IEditor)

    # Fired when an editor has been closed.
    editor_closed = Delegate("layout")  # Event(IEditor)

    # Views -------------------------

    # The active view.
    active_view = Instance(IView)

    # The available views (note that this is *all* of the views, not just those
    # currently visible).
    #
    # Views *cannot* be shared between windows as each view has a reference to
    # its toolkit-specific control etc.
    views = List(IView)

    # Perspectives -----------------#

    # The active perspective.
    active_perspective = Instance(IPerspective)

    # The available perspectives. If no perspectives are specified then a
    # single instance of the 'Perspective' class is created.
    perspectives = List(IPerspective)

    # The Id of the default perspective.
    #
    # There are two situations in which this is used:
    #
    # 1. When the window is being created from scratch (i.e., not restored).
    #
    #    If this is the empty string, then the first perspective in the list of
    #    perspectives is shown (if there are no perspectives then an instance
    #    of the default 'Perspective' class is used). If this is *not* the
    #    empty string then the perspective with this Id is shown.
    #
    # 2. When the window is being restored.
    #
    #    If this is the empty string, then the last perspective that was
    #    visible when the window last closed is shown. If this is not the empty
    #    string then the perspective with this Id is shown.
    #
    default_perspective_id = Str()

    # 'WorkbenchWindow' interface -----------------------------------------#

    # The window layout is responsible for creating and managing the internal
    # structure of the window (i.e., it knows how to add and remove views and
    # editors etc).
    layout = Instance(WorkbenchWindowLayout)

    # 'Private' interface -------------------------------------------------#

    # The state of the window suitable for pickling etc.
    _memento = Instance(WorkbenchWindowMemento)

    # ------------------------------------------------------------------------
    # 'Window' interface.
    # ------------------------------------------------------------------------

    def open(self):
        """ Open the window, honouring a veto on the 'opening' event.

        A Vetoable is fired through the 'opening' trait.  Unless a listener
        vetoes it, the toolkit control is created (if it does not exist yet),
        the window is shown and 'opened' is fired.

        Return True if the window has a control afterwards (it opened, or
        was already created); False otherwise, e.g. when the open event was
        vetoed before the control was ever created.

        """
        logger.debug("window %s opening", self)

        event = Vetoable()
        self.opening = event  # Listeners may set event.veto.

        if event.veto:
            logger.debug("window %s open was vetoed", self)
        else:
            if self.control is None:
                self._create()
            self.show(True)
            self.opened = self
            logger.debug("window %s opened", self)

        # fixme: The return value is not part of the Pyface 'Window' API (but
        # maybe it should be); it reports whether the window has a control.
        return self.control is not None

    def close(self):
        """ Close the window, honouring a veto on the 'closing' event.

        If the window is open, a Vetoable is fired through 'closing'.  Unless
        it is vetoed:
        - views and editors are destroyed,
        - the layout is closed,
        - the toolkit control is destroyed and cleared,
        - 'closed' is fired.

        Return True if the window closed successfully (or was not even open!),
        False if the close event was vetoed.

        """
        logger.debug("window %s closing", self)

        if self.control is None:
            logger.debug("window %s is not open", self)
            return self.control is None

        event = Vetoable()
        self.closing = event  # fixme: Hack to mimic vetoable events!

        if event.veto:
            logger.debug("window %s close was vetoed", self)
            return self.control is None

        # Let views and editors clean up after themselves, then the window
        # layout (event handlers, etc.), then the toolkit-specific control.
        self.destroy_views(self.views)
        self.destroy_editors(self.editors)
        self.layout.close()
        self.destroy()

        # Forget the control so that the window can (at least in theory) be
        # opened again.
        self.control = None
        self.closed = self
        logger.debug("window %s closed", self)

        # FIXME v3: The return value is not part of the Pyface 'Window' API.
        return self.control is None

    # ------------------------------------------------------------------------
    # Protected 'Window' interface.
    # ------------------------------------------------------------------------

    def _create_contents(self, parent):
        """ Create and return the window contents.

        Builds the initial layout under 'parent' and snapshots it.  It then
        either restores the contents from an existing memento or starts a
        fresh one, and finally activates the initial perspective.

        """
        contents = self.layout.create_initial_layout(parent)

        # Kept so the layout can be reset when switching to a perspective
        # that has not been seen yet.
        self._initial_layout = self.layout.get_view_memento()

        if self._memento is not None:
            # Restoring a previously saved window.
            self._restore_contents()
        else:
            # Creating the window from scratch.
            self._memento = WorkbenchWindowMemento()

        self.active_perspective = self._get_initial_perspective()
        return contents

    # ------------------------------------------------------------------------
    # 'WorkbenchWindow' interface.
    # ------------------------------------------------------------------------

    # Initializers ---------------------------------------------------------

    def _editor_manager_default(self):
        """ Trait initializer: an EditorManager bound to this window. """

        # Imported inside the method (presumably to avoid an import cycle).
        # The import must be explicitly relative: the implicit relative form
        # 'from editor_manager import ...' only works on Python 2 and raises
        # ImportError on Python 3.
        from .editor_manager import EditorManager

        return EditorManager(window=self)

    def _layout_default(self):
        """ Trait initializer: a WorkbenchWindowLayout bound to this window. """

        return WorkbenchWindowLayout(window=self)

    # Methods -------------------------------------------------------------#

    def activate_editor(self, editor):
        """ Activate an editor (delegates to the window layout). """

        self.layout.activate_editor(editor)

    def activate_view(self, view):
        """ Activate a view (delegates to the window layout). """

        self.layout.activate_view(view)

    def add_editor(self, editor, title=None):
        """ Add an editor to the window layout and to 'editors'.

        The editor's 'name' is used as the title unless 'title' is given.

        """
        self.layout.add_editor(editor, editor.name if title is None else title)
        self.editors.append(editor)

    def add_view(self, view, position=None, relative_to=None, size=(-1, -1)):
        """ Add a view to the window layout, registering it in 'views' if new. """

        self.layout.add_view(view, position, relative_to, size)

        # Views may be created and added dynamically, i.e. the window did not
        # know about them when it was created, so register them here.
        if view not in self.views:
            self.views.append(view)

    def close_editor(self, editor):
        """ Close an editor (delegates to the window layout). """

        self.layout.close_editor(editor)

    def close_view(self, view):
        """ Closes a view.

        Implemented by hiding the view via 'hide_view'.

        fixme: Currently views are never 'closed' in the same sense as an
        editor is closed. Views are merely hidden.

        """

        self.hide_view(view)

    def create_editor(self, obj, kind=None):
        """ Create an editor for an object.

        Delegates to the window's editor manager, passing this window, the
        object and 'kind' through.

        Return None if no editor can be created for the object.

        """

        return self.editor_manager.create_editor(self, obj, kind)

    def destroy_editors(self, editors):
        """ Destroy the toolkit controls of the given editors.

        Editors whose control is already None are skipped.

        """
        live = (editor for editor in editors if editor.control is not None)
        for editor in live:
            editor.destroy_control()

    def destroy_views(self, views):
        """ Destroy the toolkit controls of the given views.

        Views whose control is already None are skipped.

        """
        live = (view for view in views if view.control is not None)
        for view in live:
            view.destroy_control()

    def edit(self, obj, kind=None, use_existing=True):
        """ Edit an object.

        'kind' is simply passed through to the window's editor manager to
        allow it to create a particular kind of editor depending on context
        etc.

        If 'use_existing' is True and the object is already being edited in
        the window then the existing editor will be activated (i.e., given
        focus, brought to the front, etc.).

        If 'use_existing' is False, then a new editor will be created even if
        one already exists.

        Return the editor that was activated.  If the editor manager cannot
        create an editor for the object, a warning is logged and None is
        returned.

        """

        if use_existing:
            # Is the object already being edited in the window?
            editor = self.get_editor(obj, kind)

            if editor is not None:
                # If so, activate the existing editor (i.e., bring it to the
                # front, give it the focus etc).
                self.activate_editor(editor)
                return editor

        # Otherwise, create an editor for it.
        editor = self.create_editor(obj, kind)

        if editor is None:
            # 'create_editor' returns None when no editor can be created.
            # Bail out rather than handing None to add_editor(), which would
            # fail on 'editor.name'.
            logger.warning("no editor for object %s", obj)
            return None

        self.add_editor(editor)
        self.activate_editor(editor)

        return editor

    def get_editor(self, obj, kind=None):
        """ Return the editor that is editing an object.

        Delegates to the window's editor manager.

        Return None if no such editor exists.

        """

        return self.editor_manager.get_editor(self, obj, kind)

    def get_editor_by_id(self, id):
        """ Return the first editor in 'editors' whose Id equals 'id'.

        Return None if no such editor exists.

        """
        return next((editor for editor in self.editors if editor.id == id), None)

    def get_part_by_id(self, id):
        """ Return the workbench part with the specified Id.

        Views are searched first, then editors.  A falsy view result also
        falls through to the editor lookup, because of the 'or'.

        Return None if no such part exists.

        """

        return self.get_view_by_id(id) or self.get_editor_by_id(id)

    def get_perspective_by_id(self, id):
        """ Return the first perspective in 'perspectives' with the given Id.

        If none matches but 'id' is Perspective.DEFAULT_ID, a new default
        Perspective instance is returned.

        Return None if no such perspective exists.

        """
        found = next(
            (perspective for perspective in self.perspectives
             if perspective.id == id),
            None,
        )
        if found is None and id == Perspective.DEFAULT_ID:
            found = Perspective()
        return found

    def get_perspective_by_name(self, name):
        """ Return the first perspective in 'perspectives' with the given name.

        Return None if no such perspective exists.

        """
        return next(
            (perspective for perspective in self.perspectives
             if perspective.name == name),
            None,
        )

    def get_view_by_id(self, id):
        """ Return the first view in 'views' whose Id equals 'id'.

        Return None if no such view exists.

        """
        return next((view for view in self.views if view.id == id), None)

    def hide_editor_area(self):
        """ Hide the editor area (delegates to the window layout). """

        self.layout.hide_editor_area()

    def hide_view(self, view):
        """ Hide a view (delegates to the window layout). """

        self.layout.hide_view(view)

    def refresh(self):
        """ Refresh the window to reflect any changes (delegates to the layout). """

        self.layout.refresh()

    def reset_active_perspective(self):
        """ Reset the active perspective back to its original contents.

        Any saved memento for the active perspective is discarded.  The
        perspective is then re-displayed, which calls its 'create_contents'
        method again because no memento exists any more.

        """
        perspective = self.active_perspective
        mementos = self._memento.perspective_mementos

        if perspective.id in mementos:
            del mementos[perspective.id]

        self._show_perspective(perspective, perspective)

    def reset_all_perspectives(self):
        """ Reset all perspectives back to their original contents.

        Every saved perspective memento is discarded, except those whose Id
        starts with "__user_perspective".  The active perspective is then
        re-displayed.

        """
        mementos = self._memento.perspective_mementos

        # Snapshot the keys first: deleting entries while iterating over a
        # dict's live keys() view raises RuntimeError on Python 3.
        for perspective_id in list(mementos.keys()):
            if not perspective_id.startswith("__user_perspective"):
                del mementos[perspective_id]

        # Re-display the active perspective.
        self._show_perspective(
            self.active_perspective, self.active_perspective
        )

    def reset_editors(self):
        """ Activate the first editor in every tab (delegates to the layout). """

        self.layout.reset_editors()

    def reset_views(self):
        """ Activate the first view in every tab (delegates to the layout). """

        self.layout.reset_views()

    def show_editor_area(self):
        """ Show the editor area (delegates to the window layout). """

        self.layout.show_editor_area()

    def show_view(self, view):
        """ Show a view.

        A view that is already in the window layout (but hidden) is simply
        shown.  Otherwise it is added in its default position and the window
        is refreshed.

        """
        # fixme: Reaching into the window layout here is a little gorpy, but
        # currently it is the only thing that knows whether or not the view
        # exists but is hidden.
        if not self.layout.contains_view(view):
            self._add_view_in_default_position(view)
            self.refresh()
            return

        self.layout.show_view(view)

    # Methods for saving and restoring the layout -------------------------#

    def get_memento(self):
        """ Capture the window state and return it (suitable for pickling).

        The window's own memento object is updated in place with:

        - the size and position,
        - the active perspective id,
        - that perspective's layout tuple (view memento, active view id or
          None, editor-area visibility),
        - the editor-area memento and any toolkit-specific data.

        That same object is returned.
        """

        memento = self._memento
        perspective_id = self.active_perspective.id

        # Window geometry.
        memento.size = self.size
        memento.position = self.position

        # Which perspective is active.
        memento.active_perspective_id = perspective_id

        # Layout of the active perspective.
        view_memento = self.layout.get_view_memento()
        current_view = self.active_view
        current_view_id = (current_view.id or None) if current_view else None
        editor_area_shown = self.layout.is_editor_area_visible()
        memento.perspective_mementos[perspective_id] = (
            view_memento,
            current_view_id,
            editor_area_shown,
        )

        # Layout of the editor area, then toolkit-specific extras.
        memento.editor_area_memento = self.layout.get_editor_memento()
        memento.toolkit_data = self.layout.get_toolkit_memento()

        return memento

    def set_memento(self, memento):
        """ Store ``memento`` so the window can be restored from it later.

        Only a reference is kept here; the memento is not applied until the
        window is opened. Setting the memento of an already-open window
        therefore has no immediate effect, which is not expected to be needed.
        """

        self._memento = memento

    # ------------------------------------------------------------------------
    # Private interface.
    # ------------------------------------------------------------------------

    def _add_view_in_default_position(self, view):
        """ Adds a view in its 'default' position. """

        # Is the view in the current perspectives contents list? If it is then
        # we use the positioning information in the perspective item. Otherwise
        # we will use the default positioning specified in the view itself.
        item = self._get_perspective_item(self.active_perspective, view)
        if item is None:
            item = view

        # fixme: This only works because 'PerspectiveItem' and 'View' have the
        # identical 'position', 'relative_to', 'width' and 'height' traits! We
        # need to unify these somehow!
        relative_to = self.get_view_by_id(item.relative_to)
        size = (item.width, item.height)

        self.add_view(view, item.position, relative_to, size)

    def _get_initial_perspective(self, *methods):
        """ Return the initial perspective. """

        methods = [
            # If a default perspective was specified then we prefer that over
            # any other perspective.
            self._get_default_perspective,
            # If there was no default perspective then try the perspective that
            # was active the last time the application was run.
            self._get_previous_perspective,
            # If there was no previous perspective, then try the first one that
            # we know about.
            self._get_first_perspective,
        ]

        for method in methods:
            perspective = method()
            if perspective is not None:
                break

        # If we have no known perspectives, make a new blank one up.
        else:
            logger.warn("no known perspectives - creating a new one")
            perspective = Perspective()

        return perspective

    def _get_default_perspective(self):
        """ Return the default perspective.

        Return None if no default perspective was specified or it no longer
        exists.

        """

        id = self.default_perspective_id

        if len(id) > 0:
            perspective = self.get_perspective_by_id(id)
            if perspective is None:
                logger.warn("default perspective %s no longer available", id)

        else:
            perspective = None

        return perspective

    def _get_previous_perspective(self):
        """ Return the previous perspective.

        Return None if there has been no previous perspective or it no longer
        exists.

        """

        id = self._memento.active_perspective_id

        if len(id) > 0:
            perspective = self.get_perspective_by_id(id)
            if perspective is None:
                logger.warn("previous perspective %s no longer available", id)

        else:
            perspective = None

        return perspective

    def _get_first_perspective(self):
        """ Return the first perspective in our list of perspectives.

        Return None if no perspectives have been defined.

        """

        if len(self.perspectives) > 0:
            perspective = self.perspectives[0]

        else:
            perspective = None

        return perspective

    def _get_perspective_item(self, perspective, view):
        """ Return the perspective item for a view.

        Return None if the view is not mentioned in the perspectives contents.

        """

        # fixme: Errrr, shouldn't this be a method on the window?!?
        for item in perspective.contents:
            if item.id == view.id:
                break

        else:
            item = None

        return item

    def _hide_perspective(self, perspective):
        """ Hide a perspective. """

        # fixme: This is a bit ugly but... when we restore the layout we ignore
        # the default view visibility.
        for view in self.views:
            view.visible = False

        # Save the current layout of the perspective.
        self._memento.perspective_mementos[perspective.id] = (
            self.layout.get_view_memento(),
            self.active_view and self.active_view.id or None,
            self.layout.is_editor_area_visible(),
        )

    def _show_perspective(self, old, new):
        """ Show a perspective.

        If a memento is stored for ``new.id`` (the perspective was shown
        before), its editor-area visibility, view layout and active view are
        restored from that memento.

        Otherwise the perspective is shown for the first time:

        - the layout is reset to ``self._initial_layout`` (only when there is
          an ``old`` perspective);
        - ``new.create(self)`` builds its contents;
        - the active view is cleared;
        - the editor area is shown or hidden according to
          ``new.show_editor_area``.

        In both cases ``new.show(self)`` is called. The window is refreshed
        when ``old`` is not None.

        Parameters
        ----------
        old : the previously active perspective, or None.
        new : the perspective to show.
        """

        # If the perspective has been seen before then restore it.
        memento = self._memento.perspective_mementos.get(new.id)

        if memento is not None:
            # Mementos stored by get_memento/_hide_perspective are 3-tuples of
            # (view memento, active view id, editor area visible); mementos
            # from older versions are 2-tuples without the visibility flag.
            # Show the editor area?
            # We need to set the editor area before setting the views
            if len(memento) == 2:
                logger.warning("Restoring perspective from an older version.")
                editor_area_visible = True
            else:
                editor_area_visible = memento[2]

            # Show the editor area if it is set to be visible
            if editor_area_visible:
                self.show_editor_area()
            else:
                self.hide_editor_area()
                # No editor can be active while the editor area is hidden.
                self.active_editor = None

            # Now set the views
            view_memento, active_view_id = memento[:2]
            self.layout.set_view_memento(view_memento)

            # Make sure the active part, view and editor reflect the new
            # perspective. The saved view may no longer exist, in which case
            # the current active view is left unchanged.
            view = self.get_view_by_id(active_view_id)
            if view is not None:
                self.active_view = view

        # Otherwise, this is the first time the perspective has been seen
        # so create it.
        else:
            if old is not None:
                # Reset the window layout to its initial state.
                self.layout.set_view_memento(self._initial_layout)

            # Create the perspective in the window.
            new.create(self)

            # Make sure the active part, view and editor reflect the new
            # perspective.
            self.active_view = None

            # Show the editor area?
            if new.show_editor_area:
                self.show_editor_area()
            else:
                self.hide_editor_area()
                self.active_editor = None

        # Inform the perspective that it has been shown.
        new.show(self)

        # This forces the dock window to update its layout. No refresh is
        # done when there was no previous perspective.
        if old is not None:
            self.refresh()

    def _restore_contents(self):
        """ Restore the contents of the window. """

        self.layout.set_editor_memento(self._memento.editor_area_memento)

        self.size = self._memento.size
        self.position = self._memento.position

        # Set the toolkit-specific data last because it may override the generic
        # implementation.
        # FIXME: The primary use case is to let Qt restore the window's geometry
        # wholesale, including maximization state. If we ever go Qt-only, this
        # is a good area to refactor.
        self.layout.set_toolkit_memento(self._memento)

        return

    # Trait change handlers ------------------------------------------------

    # Static ----

    def _active_perspective_changed(self, old, new):
        """ Static trait change handler.

        The outgoing perspective (if any) is hidden first, then the incoming
        one (if any) is shown.
        """

        logger.debug("active perspective changed from <%s> to <%s>", old, new)

        outgoing, incoming = old, new

        if outgoing is not None:
            self._hide_perspective(outgoing)

        if incoming is not None:
            self._show_perspective(outgoing, incoming)

    def _active_editor_changed(self, old, new):
        """ Static trait change handler: the newly active editor also becomes
        the window's active part.
        """

        logger.debug("active editor changed from <%s> to <%s>", old, new)
        editor = new
        self.active_part = editor

    def _active_part_changed(self, old, new):
        """ Static trait change handler.

        The window selection mirrors the active part's selection, or becomes
        an empty list when there is no active part.
        """

        self.selection = [] if new is None else new.selection

        logger.debug("active part changed from <%s> to <%s>", old, new)

    def _active_view_changed(self, old, new):
        """ Static trait change handler: the newly active view also becomes
        the window's active part.
        """

        logger.debug("active view changed from <%s> to <%s>", old, new)
        view = new
        self.active_part = view

    def _views_changed(self, old, new):
        """ Static trait change handler. """

        # Cleanup any old views.
        for view in old:
            view.window = None

        # Initialize any new views.
        for view in new:
            view.window = self

    def _views_items_changed(self, event):
        """ Static trait change handler. """

        # Cleanup any old views.
        for view in event.removed:
            view.window = None

        # Initialize any new views.
        for view in event.added:
            view.window = self

        return

    # Dynamic ----

    @on_trait_change("layout.editor_closed")
    def _on_editor_closed(self, editor):
        """ Dynamic trait change handler.

        Removes the closed editor from ``self.editors``. If it was the active
        editor, the editor now at (or just before) its old position is
        activated later in the GUI thread; if no editors remain, the active
        editor is cleared.
        """

        if editor is None or editor is Undefined:
            return

        position = self.editors.index(editor)
        del self.editors[position]

        if editor is not self.active_editor:
            return

        if not len(self.editors):
            self.active_editor = None
            return

        # If the user closed the editor manually then this method is being
        # called from a toolkit-specific event handler, so the focus must not
        # be changed directly from here; activate the editor later in the GUI
        # thread instead.
        successor = self.editors[min(position, len(self.editors) - 1)]
        GUI.invoke_later(self.activate_editor, successor)

    @on_trait_change("editors.has_focus")
    def _on_editor_has_focus_changed(self, obj, trait_name, old, new):
        """ Dynamic trait change handler: an editor gaining focus becomes the
        active editor.
        """

        if trait_name != "has_focus" or not new:
            return

        self.active_editor = obj

    @on_trait_change("views.has_focus")
    def _has_focus_changed_for_view(self, obj, trait_name, old, new):
        """ Dynamic trait change handler: a view gaining focus becomes the
        active view.
        """

        if trait_name != "has_focus" or not new:
            return

        self.active_view = obj

    @on_trait_change("views.visible")
    def _visible_changed_for_view(self, obj, trait_name, old, new):
        """ Dynamic trait change handler: hiding the active view clears the
        window's active view.
        """

        hidden_active_view = (
            trait_name == "visible" and not new and obj is self.active_view
        )
        if hidden_active_view:
            self.active_view = None
コード例 #29
0
ファイル: sources.py プロジェクト: linomp/acoular
class UncorrelatedNoiseSource(SamplesGenerator):
    """
    Class to simulate white or pink noise as uncorrelated signal at each
    channel.

    The output is being generated via the :meth:`result` generator.
    """

    #: Type of noise to generate at the channels.
    #: The `~acoular.signals.SignalGenerator`-derived class has to
    #: feature the parameter "seed" (i.e. white or pink noise).
    signal = Trait(SignalGenerator, desc="type of noise")

    #: Array with seeds for random number generator.
    #: When left empty, arange(:attr:`numchannels`) + :attr:`signal`.seed
    #: will be used.
    seed = CArray(dtype=uint32, desc="random seed values")

    #: Number of channels in output; is set automatically /
    #: depends on used microphone geometry.
    numchannels = Delegate('mics', 'num_mics')

    #: :class:`~acoular.microphones.MicGeom` object that provides the microphone locations.
    mics = Trait(MicGeom, desc="microphone geometry")

    # --- List of backwards compatibility traits and their setters/getters -----------

    # Microphone locations.
    # Deprecated! Use :attr:`mics` trait instead.
    mpos = Property()

    def _get_mpos(self):
        return self.mics

    def _set_mpos(self, mpos):
        warn("Deprecated use of 'mpos' trait. ", Warning, stacklevel=2)
        self.mics = mpos

    # --- End of backwards compatibility traits --------------------------------------

    #: Start time of the signal in seconds, defaults to 0 s.
    start_t = Float(0.0, desc="signal start time")

    #: Start time of the data acquisition at microphones in seconds,
    #: defaults to 0 s.
    start = Float(0.0, desc="sample start time")

    #: Number of samples is set automatically /
    #: depends on :attr:`signal`.
    numsamples = Delegate('signal')

    #: Sampling frequency of the signal; is set automatically /
    #: depends on :attr:`signal`.
    sample_freq = Delegate('signal')

    # internal identifier
    # NOTE(review): 'loc' is not a trait defined on this class (it appears to
    # be copied from PointSource); confirm whether it can be dropped from
    # depends_on without affecting existing digests.
    digest = Property(
        depends_on=['mics.digest', 'signal.rms', 'signal.numsamples',
                    'signal.sample_freq', 'signal.__class__', 'seed', 'loc',
                    'start_t', 'start', '__class__'],
        )

    @cached_property
    def _get_digest(self):
        return digest(self)

    def result(self, num=128):
        """
        Python generator that yields the output at microphones block-wise.

        Parameters
        ----------
        num : integer, defaults to 128
            This parameter defines the size of the blocks to be yielded
            (i.e. the number of samples per block).

        Returns
        -------
        Samples in blocks of shape (num, numchannels).
            The last block may be shorter than num. No empty block is yielded,
            even when :attr:`numsamples` is a multiple of num.

        Raises
        ------
        ValueError
            If :attr:`seed` is non-empty and its shape is not (numchannels,).
        """

        Noise = self.signal.__class__

        # Create or get the array of random seeds. `seed` is a numpy array,
        # so emptiness must be tested via `.size`. `not self.seed` raises
        # ValueError for arrays with more than one element, and it would
        # treat a single user-supplied seed of 0 as "no seed given".
        if self.seed.size == 0:
            seed = arange(self.numchannels) + self.signal.seed
        elif self.seed.shape == (self.numchannels, ):
            seed = self.seed
        else:
            raise ValueError(
                "Seed array expected to be of shape (%i,), but has shape %s."
                % (self.numchannels, str(self.seed.shape)))

        # Create array with [numchannels] noise signal tracks: one noise
        # realisation per channel, each generated with its own seed, then
        # transposed so that rows are samples and columns are channels.
        signal = array([Noise(seed=s,
                              numsamples=self.numsamples,
                              sample_freq=self.sample_freq,
                              rms=self.signal.rms).signal()
                        for s in seed]).T

        # Yield consecutive blocks of `num` samples. The final block holds
        # whatever remains, and only non-empty blocks are produced.
        for first in range(0, self.numsamples, num):
            yield signal[first:first + num, :]
コード例 #30
0
ファイル: sources.py プロジェクト: linomp/acoular
class PointSource(SamplesGenerator):
    """
    Class to define a fixed point source with an arbitrary signal.
    This can be used in simulations.

    The source signal is delayed by the propagation time to each microphone
    and scaled by the inverse of the source-microphone distance, as computed
    by :attr:`env`.

    The output is being generated via the :meth:`result` generator.
    """

    #:  Emitted signal, instance of the :class:`~acoular.signals.SignalGenerator` class.
    signal = Trait(SignalGenerator)

    #: Location of source in (`x`, `y`, `z`) coordinates (left-oriented system).
    loc = Tuple((0.0, 0.0, 1.0), desc="source location")

    #: Number of channels in output, is set automatically /
    #: depends on used microphone geometry.
    numchannels = Delegate('mics', 'num_mics')

    #: :class:`~acoular.microphones.MicGeom` object that provides the microphone locations.
    mics = Trait(MicGeom, desc="microphone geometry")

    #: :class:`~acoular.environments.Environment` or derived object,
    #: which provides information about the sound propagation in the medium.
    # NOTE(review): the default Environment() instance is created once, when
    # the class is defined. Confirm whether Traits shares it between all
    # PointSource instances; the deprecated 'c' setter below mutates
    # self.env in place.
    env = Trait(Environment(), Environment)

    # --- List of backwards compatibility traits and their setters/getters -----------

    # Microphone locations.
    # Deprecated! Use :attr:`mics` trait instead.
    mpos = Property()

    def _get_mpos(self):
        return self.mics

    def _set_mpos(self, mpos):
        warn("Deprecated use of 'mpos' trait. ", Warning, stacklevel=2)
        self.mics = mpos

    # The speed of sound.
    # Deprecated! Only kept for backwards compatibility.
    # Now governed by :attr:`env` trait.
    c = Property()

    def _get_c(self):
        return self.env.c

    def _set_c(self, c):
        warn("Deprecated use of 'c' trait. ", Warning, stacklevel=2)
        # Writes through to the environment object itself.
        self.env.c = c

    # --- End of backwards compatibility traits --------------------------------------

    #: Start time of the signal in seconds, defaults to 0 s.
    start_t = Float(0.0, desc="signal start time")

    #: Start time of the data acquisition at microphones in seconds,
    #: defaults to 0 s.
    start = Float(0.0, desc="sample start time")

    #: Upsampling factor, internal use, defaults to 16.
    up = Int(16, desc="upsampling factor")

    #: Number of samples, is set automatically /
    #: depends on :attr:`signal`.
    numsamples = Delegate('signal')

    #: Sampling frequency of the signal, is set automatically /
    #: depends on :attr:`signal`.
    sample_freq = Delegate('signal')

    # internal identifier
    digest = Property(
        depends_on = ['mics.digest', 'signal.digest', 'loc', \
         'env.digest', 'start_t', 'start', 'up', '__class__'],
        )

    @cached_property
    def _get_digest(self):
        return digest(self)

    def result(self, num=128):
        """
        Python generator that yields the output at microphones block-wise.

        Parameters
        ----------
        num : integer, defaults to 128
            This parameter defines the size of the blocks to be yielded
            (i.e. the number of samples per block).

        Returns
        -------
        Samples in blocks of shape (num, numchannels).
            The last block may be shorter than num.

        Notes
        -----
        Every full block is the same ``out`` buffer object, which is
        overwritten after it has been yielded. Callers that keep blocks beyond
        the next iteration must copy them. The final partial block is a slice
        (view) of that same buffer.

        Generation stops early, without error, once the requested emission
        index runs past the end of the upsampled signal.
        """
        # If signal samples are needed for emission times before start_t
        # (negative indices), numpy's negative indexing takes them from the
        # end of the calculated signal.

        # Upsampled source signal, by a factor of self.up.
        signal = self.signal.usignal(self.up)
        # Output buffer, reused for every full block.
        out = empty((num, self.numchannels))
        # distances from the source location to each microphone
        # (shape presumably (numchannels,) -- TODO confirm against Environment._r)
        rm = self.env._r(array(self.loc).reshape((3, 1)), self.mics.mpos)
        # emission time relative to start_t (in samples) for first sample
        ind = (-rm / self.env.c - self.start_t + self.start) * self.sample_freq
        # i: fill position inside the current block; n: samples left to produce
        i = 0
        n = self.numsamples
        while n:
            n -= 1
            try:
                # Index into the upsampled signal: adding 0.5 before the int64
                # cast rounds to nearest for non-negative values (the cast
                # truncates toward zero). The result is scaled by 1/rm.
                out[i] = signal[array(0.5 + ind * self.up, dtype=int64)] / rm
                # Advance every channel's emission index by one output sample.
                ind += 1.
                i += 1
                if i == num:
                    yield out
                    i = 0
            except IndexError:  #if no more samples available from the source
                break
        if i > 0:  # if there are still samples to yield
            yield out[:i]