示例#1
0
 def __init__(self, component=None, **kwargs):
     """ Creates the axis and optionally attaches it to a component.

     The tick generator is installed first, then the remaining keyword
     arguments are handed to the superclass initializer, and only after
     that is **component** assigned (when it is not None).
     """
     # TODO: change this back to a factory in the instance trait some day
     self.tick_generator = DefaultTickGenerator()
     # Override init so that our component gets set last.  We want the
     # _component_changed() event handler to get run last.
     super(PlotAxis, self).__init__(**kwargs)
     if component is not None:
         self.component = component
示例#2
0
 def __init__(self, component=None, **kwargs):
     """ Creates the axis and optionally attaches it to a component.

     The tick generator is installed first, then the remaining keyword
     arguments are handed to the superclass initializer, and only after
     that is **component** assigned (when it is not None).
     """
     # TODO: change this back to a factory in the instance trait some day
     self.tick_generator = DefaultTickGenerator()
     # Override init so that our component gets set last.  We want the
     # _component_changed() event handler to get run last.
     super(PlotAxis, self).__init__(**kwargs)
     if component is not None:
         self.component = component
示例#3
0
class PlotAxis(AbstractOverlay):
    """
    The PlotAxis is a visual component that can be rendered on its own as
    a standalone component or attached as an overlay to another component.
    (To attach it as an overlay, set its **component** attribute.)

    When it is attached as an overlay, it draws into the padding around
    the component.
    """

    # The mapper that drives this axis.
    mapper = Instance(AbstractMapper)

    # Keep an origin for plots that aren't attached to a component
    origin = Enum("bottom left", "top left", "bottom right", "top right")

    # The text of the axis title.
    title = Trait('', Str, Unicode) #May want to add PlotLabel option

    # The font of the title.
    title_font = KivaFont('modern 12')

    # The spacing between the axis line and the title
    title_spacing = Trait('auto', 'auto', Float)

    # The color of the title.
    title_color = ColorTrait("black")

    # The thickness (in pixels) of each tick.
    tick_weight = Float(1.0)

    # The color of the ticks.
    tick_color = ColorTrait("black")

    # The font of the tick labels.
    tick_label_font = KivaFont('modern 10')

    # The color of the tick labels.
    tick_label_color = ColorTrait("black")

    # The rotation of the tick labels.
    tick_label_rotate_angle = Float(0)

    # Whether to align to corners or edges (corner is better for 45 degree rotation)
    tick_label_alignment = Enum('edge', 'corner')

    # The margin around the tick labels.
    tick_label_margin = Int(2)

    # The distance of the tick label from the axis.
    tick_label_offset = Float(8.)

    # Whether the tick labels appear to the inside or the outside of the plot area
    tick_label_position = Enum("outside", "inside")

    # A callable that is passed the numerical value of each tick label and
    # that returns a string.
    tick_label_formatter = Callable(DEFAULT_TICK_FORMATTER)

    # The number of pixels by which the ticks extend into the plot area.
    tick_in = Int(5)

    # The number of pixels by which the ticks extend into the label area.
    tick_out = Int(5)

    # Are ticks visible at all?
    tick_visible = Bool(True)

    # The dataspace interval between ticks.
    tick_interval = Trait('auto', 'auto', Float)

    # A callable that implements the AbstractTickGenerator interface.
    tick_generator = Instance(AbstractTickGenerator)

    # The location of the axis relative to the plot.  This determines where
    # the axis title is located relative to the axis line.
    orientation = Enum("top", "bottom", "left", "right")

    # Is the axis line visible?
    axis_line_visible = Bool(True)

    # The color of the axis line.
    axis_line_color = ColorTrait("black")

    # The line thickness (in pixels) of the axis line.
    axis_line_weight = Float(1.0)

    # The dash style of the axis line.
    axis_line_style = LineStyle('solid')

    # A special version of the axis line that is more useful for geophysical
    # plots.
    small_haxis_style = Bool(False)

    # Does the axis ensure that its end labels fall within its bounding area?
    ensure_labels_bounded = Bool(False)

    # Does the axis prevent the ticks from being rendered outside its bounds?
    # This flag is off by default because the standard axis *does* render ticks
    # that encroach on the plot area.
    ensure_ticks_bounded = Bool(False)

    # Fired when the axis's range bounds change.
    updated = Event

    #------------------------------------------------------------------------
    # Override default values of inherited traits
    #------------------------------------------------------------------------

    # Background color (overrides AbstractOverlay). Axes usually let the color of
    # the container show through.
    bgcolor = ColorTrait("transparent")

    # Dimensions that the axis is resizable in (overrides PlotComponent).
    # Typically, axes are resizable in both dimensions.
    resizable = "hv"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # Cached position calculations

    _tick_list = List  # These are caches of their respective positions
    _tick_positions = Any #List
    _tick_label_list = Any
    _tick_label_positions = Any
    _tick_label_bounding_boxes = List
    _major_axis_size = Float
    _minor_axis_size = Float
    _major_axis = Array
    _title_orientation = Array
    _title_angle = Float
    _origin_point = Array
    _inside_vector = Array
    _axis_vector = Array
    _axis_pixel_vector = Array
    _end_axis_point = Array


    ticklabel_cache = List
    _cache_valid = Bool(False)


    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, component=None, **kwargs):
        """ Creates the axis and optionally attaches it to a component.

        The tick generator is installed first, then the remaining keyword
        arguments are handed to the superclass initializer, and only after
        that is **component** assigned (when it is not None).
        """
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()
        # Override init so that our component gets set last.  We want the
        # _component_changed() event handler to get run last.
        super(PlotAxis, self).__init__(**kwargs)
        if component is not None:
            self.component = component

    def invalidate(self):
        """ Invalidates the pre-computed layout and scaling data.

        Clears the cached tick data via _reset_cache() and then calls
        invalidate_draw() on this axis.
        """
        self._reset_cache()
        self.invalidate_draw()
        return

    def traits_view(self):
        """ Returns the AxisView object for use with Traits UI.  This method
        is called automatically by the Traits framework when .edit_traits()
        is invoked.
        """
        # Imported lazily, at call time, rather than at module import time.
        from axis_view import AxisView
        return AxisView


    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def _do_layout(self, *args, **kw):
        """ Performs layout for this axis.

        When draw ordering is in use and the axis is attached to a component,
        the axis is laid out as an overlay of that component; otherwise the
        inherited layout is used.  All arguments are forwarded unchanged.

        Overrides Component.
        """
        attached = self.use_draw_order and self.component is not None
        if not attached:
            super(PlotAxis, self)._do_layout(*args, **kw)
            return
        self._layout_as_overlay(*args, **kw)

    def overlay(self, component, gc, view_bounds=None, mode='normal'):
        """ Renders this axis on top of **component**.

        Nothing is drawn when the axis is not visible.

        Overrides AbstractOverlay.
        """
        if self.visible:
            self._draw_component(gc, view_bounds, mode, component)

    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer of a component.

        Delegates to _draw_component() without passing a component, so the
        axis geometry is computed from the axis itself.

        Overrides PlotComponent.
        """
        self._draw_component(gc, view_bounds, mode)
        return

    def _draw_component(self, gc, view_bounds=None, mode='normal', component=None):
        """ Draws the axis line, title, ticks and tick labels.

        If the cached geometry is stale it is recomputed first: from
        **component** when one is given, otherwise from the axis itself.
        The cache is marked valid once drawing has completed.

        This method is preserved for backwards compatibility. Overrides
        PlotComponent.
        """
        if not self.visible:
            return

        if not self._cache_valid:
            if component is None:
                self._calculate_geometry()
            else:
                self._calculate_geometry_overlay(component)
            self._compute_tick_positions(gc, component)
            self._compute_labels(gc)

        with gc:
            # Setting the tick label font on the shared gc up front means the
            # labels' own set_font() calls have nothing left to do.
            gc.set_font(self.tick_label_font)

            if self.axis_line_visible:
                self._draw_axis_line(gc, self._origin_point, self._end_axis_point)
            if self.title:
                self._draw_title(gc)

            self._draw_ticks(gc)
            self._draw_labels(gc)

        self._cache_valid = True


    #------------------------------------------------------------------------
    # Private draw routines
    #------------------------------------------------------------------------

    def _layout_as_overlay(self, size=None, force=False):
        """ Positions and sizes the axis inside the padding of its component.

        Left/right axes take the component's vertical extent and the width of
        the matching horizontal padding; top/bottom axes take the component's
        horizontal extent and the height of the matching vertical padding.
        Does nothing when no component is attached.  The **size** and
        **force** arguments are accepted but not used here.
        """
        host = self.component
        if host is None:
            return

        if self.orientation in ("left", "right"):
            # Vertical axis: span the component's height.
            self.y = host.y
            self.height = host.height
            if self.orientation == "left":
                self.width = host.padding_left
                self.x = host.outer_x
            elif self.orientation == "right":
                self.width = host.padding_right
                self.x = host.x2 + 1
        else:
            # Horizontal axis: span the component's width.
            self.x = host.x
            self.width = host.width
            if self.orientation == "bottom":
                self.height = host.padding_bottom
                self.y = host.outer_y
            elif self.orientation == "top":
                self.height = host.padding_top
                self.y = host.y2 + 1

    def _draw_axis_line(self, gc, startpoint, endpoint):
        """ Draws the line for the axis.

        The line is stroked from **startpoint** to **endpoint** (both rounded
        with around()) using the axis line weight, color and dash style.
        The gc state is saved and restored around the drawing.
        """
        with gc:
            gc.set_antialias(0)
            gc.set_line_width(self.axis_line_weight)
            # The trailing-underscore names are the mapped values of the
            # color and line-style traits.
            gc.set_stroke_color(self.axis_line_color_)
            gc.set_line_dash(self.axis_line_style_)
            gc.move_to(*around(startpoint))
            gc.line_to(*around(endpoint))
            gc.stroke_path()
        return


    def _draw_title(self, gc, label=None, axis_offset=None):
        """ Renders the axis title centered on the midpoint of the axis line.

        Parameters
        ----------
        gc
            Graphics context to draw into.
        label
            Optional pre-built label object to draw.  When None, a Label is
            built from the title text, font, color and angle.
        axis_offset
            Optional distance between the ticks and the title.  It is
            replaced by **title_spacing** whenever that trait is not 'auto';
            when left as None with 'auto' spacing it is derived from the
            cached tick labels (or defaults to 25 when there are none).
        """
        if label is not None:
            title_label = label
        else:
            title_label = Label(text=self.title,
                                font=self.title_font,
                                color=self.title_color,
                                rotate_angle=self.title_angle)

        # Size of the label's _rotated_ bounding box.
        label_size = array(title_label.get_bounding_box(gc), float64)
        half_size_shift = -label_size/2.0
        # Index of the screen dimension along which we move away from the
        # axis line.
        perp_index = self._major_axis.argmin()

        spacing = self.title_spacing
        if spacing != 'auto':
            axis_offset = spacing

        if spacing and axis_offset is None:
            if self.ticklabel_cache:
                largest = max([tick._bounding_box[perp_index]
                               for tick in self.ticklabel_cache])
                axis_offset = largest * 1.3
            else:
                axis_offset = 25

        anchor = (self._origin_point+self._end_axis_point)/2
        distance = self.tick_out + label_size[perp_index]/2.0 + axis_offset
        anchor -= self._inside_vector * distance
        anchor += half_size_shift

        # Shift the coordinate system to the anchor, draw, then shift back.
        gc.translate_ctm(*anchor)
        title_label.draw(gc)
        gc.translate_ctm(*(-anchor))


    def _draw_ticks(self, gc):
        """ Strokes one short line per cached tick position.

        Each tick runs from **tick_in** pixels along the inside vector to
        **tick_out** pixels against it.  Nothing is drawn when ticks are not
        visible.
        """
        if not self.tick_visible:
            return

        inward = self._inside_vector*self.tick_in
        outward = self._inside_vector*self.tick_out

        gc.set_stroke_color(self.tick_color_)
        gc.set_line_width(self.tick_weight)
        gc.set_antialias(False)
        gc.begin_path()
        for position in self._tick_positions:
            gc.move_to(*(position + inward))
            gc.line_to(*(position - outward))
        gc.stroke_path()

    def _draw_labels(self, gc):
        """ Renders every cached tick label next to its tick position.

        Labels are pushed away from the axis line by **tick_label_offset**
        plus half their own extent, on the outside of the plot area unless
        **tick_label_position** is "inside".
        """
        # Index of the screen dimension along which labels move away from
        # the axis line, and of the dimension that runs along the axis.
        perp_index = self._major_axis.argmin()

        if self.tick_label_position == "inside":
            direction = -self._inside_vector
        else:
            direction = self._inside_vector

        positions = self._tick_label_positions
        last_index = len(positions) - 1

        for index, tick_position in enumerate(positions):
            ticklabel = self.ticklabel_cache[index]
            box = self._tick_label_bounding_boxes[index]

            # Place the label so that a vector of length tick_label_offset
            # extending from the tick just touches the label's rectangular
            # bounding box.  Note: this is not necessarily optimal for axes
            # that are neither horizontal nor vertical.
            label_origin = tick_position.copy()
            distance = self.tick_label_offset + box[perp_index]/2.0
            label_origin -= direction * distance
            label_origin -= box/2.0

            if self.tick_label_alignment == 'corner':
                # Anchor on a corner of the label instead of its edge center.
                if self.orientation in ("top", "bottom"):
                    label_origin[0] += box[0]/2.0
                elif self.orientation == "left":
                    label_origin[1] -= box[1]/2.0
                elif self.orientation == "right":
                    label_origin[1] += box[1]/2.0

            if self.ensure_labels_bounded:
                # Clamp only the first and last labels along the axis.
                along_index = self._major_axis.argmax()
                if index == 0:
                    label_origin[along_index] = max(label_origin[along_index],
                                                    self._origin_point[along_index])
                elif index == last_index:
                    upper_limit = self._end_axis_point[along_index] - box[along_index]
                    label_origin[along_index] = min(label_origin[along_index],
                                                    upper_limit)

            rounded = around(label_origin)
            gc.translate_ctm(*rounded)
            ticklabel.draw(gc)
            gc.translate_ctm(*(-rounded))


    #------------------------------------------------------------------------
    # Private methods for computing positions and layout
    #------------------------------------------------------------------------

    def _reset_cache(self):
        """ Clears the cached tick positions, labels, and label positions.

        Only the three attributes assigned below are reset; ticklabel_cache
        and _tick_label_bounding_boxes are not touched here.
        """
        self._tick_positions = []
        self._tick_label_list = []
        self._tick_label_positions = []
        return

    def _compute_tick_positions(self, gc, overlay_component=None):
        """ Calculates the positions for the tick marks.

        Parameters
        ----------
        gc
            Graphics context; used to measure label bounding boxes when the
            tick generator produces labels together with ticks.
        overlay_component
            Optional component the axis is overlaid on.  When given, its
            'origin' attribute (default 'bottom left') is used instead of
            this axis's own origin.

        The results are stored in _tick_positions, _tick_label_list and
        _tick_label_positions.  When there is no mapper, or the data or
        screen range is empty or infinite, the cache is cleared and marked
        valid instead.

        Raises
        ------
        RuntimeError
            If the low end of the data range is greater than the high end.
        """
        if self.mapper is None:
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = self.mapper.range.low
        datahigh = self.mapper.range.high
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos
        if overlay_component is not None:
            origin = getattr(overlay_component, 'origin', 'bottom left')
        else:
            origin = self.origin

        # A flipped origin reverses the screen direction of the axis.
        if self.orientation in ("top", "bottom"):
            flip_from_gc = "right" in origin
        else:
            flip_from_gc = "top" in origin
        if flip_from_gc:
            screenlow, screenhigh = screenhigh, screenlow

        if (datalow == datahigh) or (screenlow == screenhigh) or \
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            # Degenerate range: there is nothing sensible to draw.
            self._reset_cache()
            self._cache_valid = True
            return

        if datalow > datahigh:
            # Call form of raise: valid on both Python 2 and Python 3.
            raise RuntimeError("DataRange low is greater than high; unable to compute axis ticks.")

        if not self.tick_generator:
            return

        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            # generate ticks and labels simultaneously
            tmp = self.tick_generator.get_ticks_and_labels(datalow, datahigh,
                                                screenlow, screenhigh)
            if len(tmp) == 0:
                tick_list = []
                labels = []
            else:
                tick_list, labels = tmp
            # compute the labels here
            # NOTE(review): unlike _compute_labels(), these labels are built
            # without rotate_angle/margin -- confirm whether that is intended.
            self.ticklabel_cache = [Label(text=lab,
                                          font=self.tick_label_font,
                                          color=self.tick_label_color) \
                                    for lab in labels]
            self._tick_label_bounding_boxes = [array(ticklabel.get_bounding_box(gc), float64) \
                                               for ticklabel in self.ticklabel_cache]
        else:
            scale = 'log' if isinstance(self.mapper, LogMapper) else 'linear'
            if self.small_haxis_style:
                # Only the two end points of the data range get ticks.
                tick_list = array([datalow, datahigh])
            else:
                tick_list = array(self.tick_generator.get_ticks(datalow, datahigh,
                                                                datalow, datahigh,
                                                                self.tick_interval,
                                                                use_endpoints=False,
                                                                scale=scale), float64)

        # Normalize the mapped screen positions to fractions of the axis
        # length, then place them along the axis vector from the origin point.
        mapped_tick_positions = (array(self.mapper.map_screen(tick_list))-screenlow) / \
                                            (screenhigh-screenlow)
        self._tick_positions = around(array([self._axis_vector*tickpos + self._origin_point \
                                for tickpos in mapped_tick_positions]))
        self._tick_label_list = tick_list
        self._tick_label_positions = self._tick_positions
        return


    def _compute_labels(self, gc):
        """ Builds the tick Label objects and measures their bounding boxes.

        Does nothing when the tick generator provides a
        get_ticks_and_labels() method, because in that case the labels were
        already created while computing the tick positions.
        """
        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            # tick labels are already computed
            return

        formatter = self.tick_label_formatter
        labels = []
        for value in self._tick_label_list:
            # Fall back to str() when no formatter is configured.
            if formatter is None:
                text = str(value)
            else:
                text = formatter(value)
            labels.append(Label(text=text,
                                font=self.tick_label_font,
                                color=self.tick_label_color,
                                rotate_angle=self.tick_label_rotate_angle,
                                margin=self.tick_label_margin))

        self.ticklabel_cache = labels
        self._tick_label_bounding_boxes = [array(item.get_bounding_box(gc), float)
                                           for item in self.ticklabel_cache]


    def _calculate_geometry(self):
        """ Computes the axis geometry from this axis's own position and bounds.

        Sets the axis sizes, the major-axis and inside direction vectors, the
        title angle, and the origin/end points of the axis line.  Reads the
        screen extent from the mapper, so a mapper must be set.
        """
        origin = self.origin
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            # Horizontal axis: the major axis runs along x.
            self._major_axis_size = self.bounds[0]
            self._minor_axis_size = self.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0.,1.])
            self.title_angle = 0.0
            if self.orientation == 'top':
                self._origin_point = array(self.position)
                self._inside_vector = array([0.,-1.])
            else: # self.orientation == 'bottom'
                self._origin_point = array(self.position) + array([0., self.bounds[1]])
                self._inside_vector = array([0., 1.])
            # Only abs(screenhigh-screenlow) is used below, so this swap does
            # not change the computed end point.
            if "right" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            # Vertical axis: the major axis runs along y.
            self._major_axis_size = self.bounds[1]
            self._minor_axis_size = self.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array(self.position) + array([self.bounds[0], 0.])
                self._inside_vector = array([1., 0.])
                self.title_angle = 90.0
            else: #self.orientation == 'right'
                self._origin_point = array(self.position)
                self._inside_vector = array([-1., 0.])
                self.title_angle = 270.0
            # As above, the swap has no effect on the abs() used below.
            if "top" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            # Shift the axis line by tick_in pixels against the inside vector.
            self._origin_point -= self._inside_vector*self.tick_in

        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # Unit-length vector along the axis (the axis vector divided by its
        # own length).
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector,self._axis_vector))
        return


    def _calculate_geometry_overlay(self, overlay_component=None):
        """ Computes the axis geometry from the component it is overlaid on.

        Sets the axis sizes, the major-axis and inside direction vectors, the
        title angle, and the origin/end points of the axis line, using the
        bounds and x/y/x2/y2 of **overlay_component**.  When no component is
        given, the axis itself is used.  Reads the screen extent from the
        mapper, so a mapper must be set.
        """
        if overlay_component is None:
            overlay_component = self
        component_origin = getattr(overlay_component, "origin", 'bottom left')

        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            # Horizontal axis: the major axis runs along x.
            self._major_axis_size = overlay_component.bounds[0]
            self._minor_axis_size = overlay_component.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0.,1.])
            self.title_angle = 0.0
            if self.orientation == 'top':
                self._origin_point = array([overlay_component.x, overlay_component.y2])
                self._inside_vector = array([0.0, -1.0])
            else:
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([0.0, 1.0])
            # Only abs(screenhigh-screenlow) is used below, so this swap does
            # not change the computed end point.
            if "right" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            # Vertical axis: the major axis runs along y.
            self._major_axis_size = overlay_component.bounds[1]
            self._minor_axis_size = overlay_component.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([1.0, 0.0])
                self.title_angle = 90.0
            else:
                self._origin_point = array([overlay_component.x2, overlay_component.y])
                self._inside_vector = array([-1.0, 0.0])
                self.title_angle = 270.0
            # As above, the swap has no effect on the abs() used below.
            if "top" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            # Shift the axis line by tick_in pixels against the inside vector.
            self._origin_point -= self._inside_vector*self.tick_in

        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # Unit-length vector along the axis (the axis vector divided by its
        # own length).
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector,self._axis_vector))
        return


    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _bounds_changed(self, old, new):
        """ Handles replacement of the bounds: defers to the superclass, flags
        that layout is needed, and invalidates the cached geometry.
        """
        super(PlotAxis, self)._bounds_changed(old, new)
        self._layout_needed = True
        self._invalidate()

    def _bounds_items_changed(self, event):
        """ Handles a change to the items of the bounds: defers to the
        superclass, flags that layout is needed, and invalidates the cached
        geometry.
        """
        super(PlotAxis, self)._bounds_items_changed(event)
        self._layout_needed = True
        self._invalidate()

    def _mapper_changed(self, old, new):
        """ Moves the "updated" listener from the old mapper to the new one
        and invalidates the cached geometry.
        """
        if old is not None:
            old.on_trait_change(self.mapper_updated, "updated", remove=True)
        if new is not None:
            new.on_trait_change(self.mapper_updated, "updated")
        self._invalidate()

    def mapper_updated(self):
        """
        Event handler that is bound to this axis's mapper's **updated** event.
        Invalidates the cached geometry so it is recomputed on the next draw.
        """
        self._invalidate()

    def _position_changed(self, old, new):
        """ Defers to the superclass handler and marks the cache as stale. """
        super(PlotAxis, self)._position_changed(old, new)
        self._cache_valid = False

    def _position_items_changed(self, event):
        """ Defers to the superclass handler and marks the cache as stale. """
        super(PlotAxis, self)._position_items_changed(event)
        self._cache_valid = False

    def _position_changed_for_component(self):
        """ Marks the cache as stale when the component's position changes. """
        self._cache_valid = False

    def _position_items_changed_for_component(self):
        """ Marks the cache as stale when an item of the component's position
        changes.
        """
        self._cache_valid = False

    def _bounds_changed_for_component(self):
        """ Marks the cache as stale and flags that layout is needed when the
        component's bounds change.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _bounds_items_changed_for_component(self):
        """ Marks the cache as stale and flags that layout is needed when an
        item of the component's bounds changes.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _origin_changed_for_component(self):
        """ Invalidates the cached geometry when the component's origin
        changes.
        """
        self._invalidate()

    def _updated_fired(self):
        """If the axis bounds changed, redraw.

        Only marks the cache as stale; no redraw is requested here.
        """
        self._cache_valid = False
        return

    def _invalidate(self):
        """ Marks the cache as stale and requests a redraw of this axis and,
        when one is attached, of its component.
        """
        self._cache_valid = False
        self.invalidate_draw()
        if self.component:
            self.component.invalidate_draw()
        return

    def _component_changed(self):
        """ Picks a mapper from the newly attached component and syncs the
        origin.

        When a mapper is already set it is left untouched and nothing else
        happens.  Otherwise the first attribute found on the component, out of
        the candidate names for this orientation, becomes the mapper, and the
        axis origin is copied from the component (default 'bottom left').
        """
        if self.mapper is not None:
            # If there is a mapper set, just leave it be.
            return

        # Candidate attribute names, in order of preference, for the most
        # appropriate mapper given our orientation.
        vertical_attrs = ("ymapper", "y_mapper", "value_mapper")
        horizontal_attrs = ("xmapper", "x_mapper", "index_mapper")
        attrmap = {"left": vertical_attrs,
                   "right": vertical_attrs,
                   "bottom": horizontal_attrs,
                   "top": horizontal_attrs}

        component = self.component
        for attr in attrmap[self.orientation]:
            if hasattr(component, attr):
                self.mapper = getattr(component, attr)
                break

        # Keep our origin in sync with the component
        self.origin = getattr(component, 'origin', 'bottom left')
        return


    #------------------------------------------------------------------------
    # The following event handlers just invalidate our previously computed
    # Label instances and backbuffer if any of our visual attributes change.
    # TODO: refactor this stuff and the caching of contained objects (e.g. Label)
    #------------------------------------------------------------------------

    def _title_changed(self):
        """ Requests a redraw of this axis and, when one is attached, of its
        component.  The cached geometry is left as it is.
        """
        self.invalidate_draw()
        if self.component:
            self.component.invalidate_draw()
        return

    def _anytrait_changed(self, name, old, new):
        """ Invalidates the axis whenever one of its visual traits changes.

        Changes to any trait not named below are ignored here.
        """
        visual_traits = (
            # title
            'title_font', 'title_spacing', 'title_color',
            # ticks
            'tick_weight', 'tick_color',
            # tick labels
            'tick_label_font', 'tick_label_color', 'tick_label_rotate_angle',
            'tick_label_alignment', 'tick_label_margin', 'tick_label_offset',
            'tick_label_position', 'tick_label_formatter',
            # tick geometry and generation
            'tick_in', 'tick_out', 'tick_visible', 'tick_interval',
            'tick_generator',
            # placement
            'orientation', 'origin',
            # axis line
            'axis_line_visible', 'axis_line_color', 'axis_line_weight',
            'axis_line_style', 'small_haxis_style',
            # bounding behavior
            'ensure_labels_bounded', 'ensure_ticks_bounded',
        )
        if name not in visual_traits:
            return
        self._invalidate()

    #------------------------------------------------------------------------
    # Persistence-related methods
    #------------------------------------------------------------------------

    def __getstate__(self):
        """ Returns the pickle state with all cached layout data removed.

        The state produced by the superclass is stripped of every cache
        attribute listed below; those values are recomputed on the next draw.
        """
        dont_pickle = [
            '_tick_list',
            '_tick_positions',
            '_tick_label_list',
            '_tick_label_positions',
            '_tick_label_bounding_boxes',
            '_major_axis_size',
            '_minor_axis_size',
            '_major_axis',
            '_title_orientation',
            '_title_angle',
            '_origin_point',
            '_inside_vector',
            '_axis_vector',
            '_axis_pixel_vector',
            '_end_axis_point',
            # The label cache trait is declared as ``ticklabel_cache`` (no
            # leading underscore); the underscored spelling is kept as well
            # so that the original exclusion is preserved.
            'ticklabel_cache',
            '_ticklabel_cache',
            '_cache_valid',
        ]

        state = super(PlotAxis, self).__getstate__()
        for key in dont_pickle:
            # ``in`` works on both Python 2 and 3, unlike dict.has_key().
            if key in state:
                del state[key]

        return state

    def __setstate__(self, state):
        """ Restores the pickled state, then re-attaches the mapper listener
        and clears the caches so everything is recomputed on the next draw.
        """
        super(PlotAxis,self).__setstate__(state)
        # Re-run the mapper handler by hand to hook up the "updated" listener.
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
示例#4
0
class PlotAxis(AbstractOverlay):
    """
    The PlotAxis is a visual component that can be rendered on its own as
    a standalone component or attached as an overlay to another component.
    (To attach it as an overlay, set its **component** attribute.)

    When it is attached as an overlay, it draws into the padding around
    the component.
    """

    # The mapper that drives this axis.
    mapper = Instance(AbstractMapper)

    # The text of the axis title.
    title = Trait('', Str, Unicode)  # May want to add PlotLabel option

    # The font of the title.
    title_font = KivaFont('modern 12')

    # The spacing between the axis line and the title.  With 'auto', the
    # spacing is derived from the tick labels when the title is drawn.
    title_spacing = Trait('auto', 'auto', Float)

    # The color of the title.
    title_color = ColorTrait("black")

    # The thickness (in pixels) of each tick.
    tick_weight = Float(1.0)

    # The color of the ticks.
    tick_color = ColorTrait("black")

    # The font of the tick labels.
    tick_label_font = KivaFont('modern 10')

    # The color of the tick labels.
    tick_label_color = ColorTrait("black")

    # The rotation of the tick labels.  (Only multiples of 90 are supported)
    tick_label_rotate_angle = Float(0)

    # Whether to align to corners or edges (corner is better for 45 degree rotation)
    tick_label_alignment = Enum('edge', 'corner')

    # The margin around the tick labels.
    tick_label_margin = Int(2)

    # The distance of the tick label from the axis.
    tick_label_offset = Float(8.)

    # Whether the tick labels appear to the inside or the outside of the plot area
    tick_label_position = Enum("outside", "inside")

    # A callable that is passed the numerical value of each tick label and
    # that returns a string.
    tick_label_formatter = Callable(DEFAULT_TICK_FORMATTER)

    # The number of pixels by which the ticks extend into the plot area.
    tick_in = Int(5)

    # The number of pixels by which the ticks extend into the label area.
    tick_out = Int(5)

    # Are ticks visible at all?
    tick_visible = Bool(True)

    # The dataspace interval between ticks.
    tick_interval = Trait('auto', 'auto', Float)

    # A callable that implements the AbstractTickGenerator interface.
    # __init__() assigns a DefaultTickGenerator instance to it.
    tick_generator = Instance(AbstractTickGenerator)

    # The location of the axis relative to the plot.  This determines where
    # the axis title is located relative to the axis line.
    orientation = Enum("top", "bottom", "left", "right")

    # Is the axis line visible?
    axis_line_visible = Bool(True)

    # The color of the axis line.
    axis_line_color = ColorTrait("black")

    # The line thickness (in pixels) of the axis line.
    axis_line_weight = Float(1.0)

    # The dash style of the axis line.
    axis_line_style = LineStyle('solid')

    # A special version of the axis line that is more useful for geophysical
    # plots.  When set, ticks are placed only at the low and high ends of
    # the data range.
    small_haxis_style = Bool(False)

    # Does the axis ensure that its end labels fall within its bounding area?
    ensure_labels_bounded = Bool(False)

    # Does the axis prevent the ticks from being rendered outside its bounds?
    # This flag is off by default because the standard axis *does* render ticks
    # that encroach on the plot area.
    ensure_ticks_bounded = Bool(False)

    # Fired when the axis's range bounds change.
    updated = Event

    #------------------------------------------------------------------------
    # Override default values of inherited traits
    #------------------------------------------------------------------------

    # Background color (overrides AbstractOverlay). Axes usually let the color of
    # the container show through.
    bgcolor = ColorTrait("transparent")

    # Dimensions that the axis is resizable in (overrides PlotComponent).
    # Typically, axes are resizable in both dimensions.
    resizable = "hv"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # Cached position calculations.  They are recomputed by the drawing code
    # whenever _cache_valid is False.

    # NOTE(review): nothing in this class assigns _tick_list; possibly unused.
    _tick_list = List  # These are caches of their respective positions
    # Rounded screen positions of the ticks (see _compute_tick_positions()).
    _tick_positions = Any  #List
    # Tick values that the tick labels are generated from.
    _tick_label_list = Any
    # Screen positions of the tick labels; set to the tick positions.
    _tick_label_positions = Any
    # Bounding box of each entry in ticklabel_cache, as an array.
    _tick_label_bounding_boxes = List
    # Extent of the bounds along and across the axis direction.
    _major_axis_size = Float
    _minor_axis_size = Float
    # Unit vector along the axis: [1, 0] for top/bottom, [0, 1] for left/right.
    _major_axis = Array
    _title_orientation = Array
    # NOTE(review): the methods of this class read and write ``title_angle``
    # (no leading underscore) rather than this trait -- confirm which name
    # is intended.
    _title_angle = Float
    # Screen position at which the axis line starts.
    _origin_point = Array
    # Unit vector pointing from the axis line towards the plot area.
    _inside_vector = Array
    # Vector from _origin_point to _end_axis_point.
    _axis_vector = Array
    # _axis_vector divided by its length.
    _axis_pixel_vector = Array
    # Screen position at which the axis line ends.
    _end_axis_point = Array

    # Label objects for the tick labels, one per tick.
    ticklabel_cache = List
    # True once the geometry, tick positions and labels have been computed
    # for the current state.
    _cache_valid = Bool(False)

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, component=None, **kwargs):
        """ Creates the axis and optionally attaches it to *component*.

        Parameters
        ----------
        component :
            Component to attach the axis to; ``None`` leaves the axis
            unattached.
        **kwargs :
            Trait values forwarded to the superclass constructor.
        """
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()

        super(PlotAxis, self).__init__(**kwargs)

        # The component is deliberately assigned after every other trait has
        # been set, so that the _component_changed() handler runs last.
        if component is None:
            return
        self.component = component

    def invalidate(self):
        """ Invalidates the pre-computed layout and scaling data.

        Clears the cached tick data through _reset_cache() and then calls
        invalidate_draw().
        """
        self._reset_cache()
        self.invalidate_draw()
        return

    def traits_view(self):
        """ Returns a View instance for use with Traits UI.  This method is
        called automatically by the Traits framework when .edit_traits() is
        invoked.
        """
        # Imported here rather than at module level, so the view module is
        # only loaded when a view is actually requested.
        from axis_view import AxisView
        return AxisView

    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def _do_layout(self, *args, **kw):
        """ Tells this component to do layout at a given size.

        Overrides Component.  When draw order is in use and the axis is
        attached to a component, the axis is laid out as an overlay;
        otherwise the superclass layout runs.
        """
        as_overlay = self.use_draw_order and self.component is not None
        if not as_overlay:
            super(PlotAxis, self)._do_layout(*args, **kw)
            return
        self._layout_as_overlay(*args, **kw)

    def overlay(self, component, gc, view_bounds=None, mode='normal'):
        """ Draws this component overlaid on another component.

        Overrides AbstractOverlay.  Nothing is drawn while the axis is not
        visible.
        """
        if self.visible:
            self._draw_component(gc, view_bounds, mode, component)

    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.  Delegates to _draw_component() without a
        component argument, i.e. the axis is drawn using its own geometry.
        """
        self._draw_component(gc, view_bounds, mode)
        return

    def _draw_component(self,
                        gc,
                        view_bounds=None,
                        mode='normal',
                        component=None):
        """ Draws the component.

        This method is preserved for backwards compatibility. Overrides
        PlotComponent.

        When the cache is stale, the geometry, tick positions and labels are
        recomputed first; the overlay geometry is used when *component* is
        given.  The axis line, title, ticks and labels are then drawn, and
        the cache is marked valid.
        """
        if not self.visible:
            return

        if not self._cache_valid:
            if component is None:
                self._calculate_geometry()
            else:
                self._calculate_geometry_overlay(component)
            self._compute_tick_positions(gc, component)
            self._compute_labels(gc)

        with gc:
            # Slight optimization: with the font already set on the base gc,
            # the set_font() calls of the title and tick labels have no work
            # left to do.
            gc.set_font(self.tick_label_font)

            if self.axis_line_visible:
                line_start = self._origin_point
                line_end = self._end_axis_point
                self._draw_axis_line(gc, line_start, line_end)

            if self.title:
                self._draw_title(gc)

            self._draw_ticks(gc)
            self._draw_labels(gc)

        self._cache_valid = True

    #------------------------------------------------------------------------
    # Private draw routines
    #------------------------------------------------------------------------

    def _layout_as_overlay(self, size=None, force=False):
        """ Lays out the axis as an overlay on another component.

        The axis takes the component's extent along its own direction and
        the matching padding of the component as its thickness.  *size* and
        *force* are accepted but not used.
        """
        host = self.component
        if host is None:
            return

        side = self.orientation
        if side in ("left", "right"):
            # Vertical axis: same vertical span as the component.
            self.y = host.y
            self.height = host.height
            if side == "left":
                self.width = host.padding_left
                self.x = host.outer_x
            elif side == "right":
                self.width = host.padding_right
                self.x = host.x2 + 1
        else:
            # Horizontal axis: same horizontal span as the component.
            self.x = host.x
            self.width = host.width
            if side == "bottom":
                self.height = host.padding_bottom
                self.y = host.outer_y
            elif side == "top":
                self.height = host.padding_top
                self.y = host.y2 + 1

    def _draw_axis_line(self, gc, startpoint, endpoint):
        """ Draws the line for the axis.

        The line runs from *startpoint* to *endpoint*, both rounded, and is
        stroked without antialiasing using the axis line weight, color and
        dash style.
        """
        line_from = around(startpoint)
        line_to = around(endpoint)
        with gc:
            gc.set_antialias(0)
            gc.set_line_width(self.axis_line_weight)
            gc.set_stroke_color(self.axis_line_color_)
            gc.set_line_dash(self.axis_line_style_)
            gc.move_to(*line_from)
            gc.line_to(*line_to)
            gc.stroke_path()

    def _draw_title(self, gc, label=None, axis_offset=None):
        """ Draws the title for the axis.

        Parameters
        ----------
        gc :
            Graphics context to draw into.
        label :
            Label to draw; when ``None`` one is built from the title traits.
        axis_offset :
            Distance between the ticks and the title.  It is replaced by
            ``title_spacing`` whenever that trait is not 'auto'; when still
            ``None`` it is derived from the cached tick labels.
        """
        title_label = label
        if title_label is None:
            title_label = Label(text=self.title,
                                font=self.title_font,
                                color=self.title_color,
                                rotate_angle=self.title_angle)

        # Bounding box of the label after rotation.
        tl_bounds = array(title_label.get_bounding_box(gc), float64)
        center_to_corner = -tl_bounds / 2.0
        # Index of the direction leading away from the axis line.
        away_index = self._major_axis.argmin()

        if self.title_spacing != 'auto':
            axis_offset = self.title_spacing

        if axis_offset is None and self.title_spacing:
            cached_labels = self.ticklabel_cache
            if cached_labels:
                widest = max(lbl._bounding_box[away_index]
                             for lbl in cached_labels)
                axis_offset = widest * 1.3
            else:
                axis_offset = 25

        # Start at the middle of the axis line, step outwards, then shift
        # from the label's center to its corner.
        offset = (self._origin_point + self._end_axis_point) / 2
        distance = self.tick_out + tl_bounds[away_index] / 2.0 + axis_offset
        offset -= self._inside_vector * distance
        offset += center_to_corner

        gc.translate_ctm(*offset)
        title_label.draw(gc)
        gc.translate_ctm(*(-offset))

    def _draw_ticks(self, gc):
        """ Draws the tick marks for the axis.

        Each tick is a segment through its cached position, reaching
        ``tick_in`` pixels along the inside vector and ``tick_out`` pixels
        the other way.  Nothing is drawn when ticks are not visible.
        """
        if not self.tick_visible:
            return

        inward = self._inside_vector * self.tick_in
        outward = self._inside_vector * self.tick_out

        gc.set_stroke_color(self.tick_color_)
        gc.set_line_width(self.tick_weight)
        gc.set_antialias(False)
        gc.begin_path()
        for position in self._tick_positions:
            inner_end = position + inward
            outer_end = position - outward
            gc.move_to(*inner_end)
            gc.line_to(*outer_end)
        gc.stroke_path()

    def _draw_labels(self, gc):
        """ Draws the tick labels for the axis.

        Each cached label is placed ``tick_label_offset`` away from its tick
        position, on the side selected by ``tick_label_position``, then
        adjusted for corner alignment and, if requested, kept within the
        ends of the axis.
        """
        # Index of the direction leading away from the axis line.
        away_index = self._major_axis.argmin()

        direction = self._inside_vector
        if self.tick_label_position == "inside":
            direction = -direction

        positions = self._tick_label_positions
        last_index = len(positions) - 1
        for i, anchor in enumerate(positions):
            # We want a more sophisticated scheme than just 2 decimals all
            # the time.
            ticklabel = self.ticklabel_cache[i]
            tl_bounds = self._tick_label_bounding_boxes[i]

            # base_position puts the tick label at a point where the vector
            # extending from the tick mark just touches the rectangular
            # bounding box of the tick label.
            # Note: This is not necessarily optimal for non
            # horizontal/vertical axes.  More work could be done on this.
            base_position = anchor.copy()
            distance = self.tick_label_offset + tl_bounds[away_index] / 2.0
            base_position -= direction * distance
            base_position -= tl_bounds / 2.0

            if self.tick_label_alignment == 'corner':
                side = self.orientation
                if side in ("top", "bottom"):
                    base_position[0] += tl_bounds[0] / 2.0
                elif side == "left":
                    base_position[1] -= tl_bounds[1] / 2.0
                elif side == "right":
                    base_position[1] += tl_bounds[1] / 2.0

            if self.ensure_labels_bounded:
                along = self._major_axis.argmax()
                if i == 0:
                    # First label: do not start before the axis origin.
                    lower = self._origin_point[along]
                    base_position[along] = max(base_position[along], lower)
                elif i == last_index:
                    # Last label: do not run past the end of the axis.
                    upper = self._end_axis_point[along] - tl_bounds[along]
                    base_position[along] = min(base_position[along], upper)

            tlpos = around(base_position)
            gc.translate_ctm(*tlpos)
            ticklabel.draw(gc)
            gc.translate_ctm(*(-tlpos))

    #------------------------------------------------------------------------
    # Private methods for computing positions and layout
    #------------------------------------------------------------------------

    def _reset_cache(self):
        """ Clears the cached tick positions, labels, and label positions.

        Each of the three caches is replaced by a new empty list.  The
        ``_cache_valid`` flag is not touched here.
        """
        self._tick_positions = []
        self._tick_label_list = []
        self._tick_label_positions = []
        return

    def _compute_tick_positions(self, gc, overlay_component=None):
        """ Calculates the positions for the tick marks.

        Parameters
        ----------
        gc :
            Graphics context, used to measure the labels when the tick
            generator produces ticks and labels together.
        overlay_component :
            Component the axis is drawn over, or ``None``.  Its ``origin``
            decides whether the screen range of the mapper is flipped.

        Raises
        ------
        RuntimeError
            If the low end of the data range is greater than the high end.

        Without a mapper, or with an empty or infinite range, the cache is
        reset and marked valid.  Without a tick generator the method returns
        and leaves the cache as it is.
        """
        if (self.mapper is None):
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = self.mapper.range.low
        datahigh = self.mapper.range.high
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos
        if overlay_component is not None:
            origin = getattr(overlay_component, 'origin', 'bottom left')
            # A horizontal axis is flipped by a right-hand origin, a vertical
            # axis by a top origin.  The plain if/else guarantees that
            # flip_from_gc is always bound.
            if self.orientation in ("top", "bottom"):
                flip_from_gc = "right" in origin
            else:
                flip_from_gc = "top" in origin

            if flip_from_gc:
                screenlow, screenhigh = screenhigh, screenlow

        if (datalow == datahigh) or (screenlow == screenhigh) or \
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        if datalow > datahigh:
            # Call form of raise: valid on both Python 2 and Python 3.
            raise RuntimeError("DataRange low is greater than high; unable to compute axis ticks.")

        if not self.tick_generator:
            return

        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            # generate ticks and labels simultaneously
            tmp = self.tick_generator.get_ticks_and_labels(
                datalow, datahigh, screenlow, screenhigh)
            if len(tmp) == 0:
                tick_list = []
                labels = []
            else:
                tick_list, labels = tmp
            # compute the labels here
            self.ticklabel_cache = [Label(text=lab,
                                          font=self.tick_label_font,
                                          color=self.tick_label_color)
                                    for lab in labels]
            self._tick_label_bounding_boxes = [
                array(ticklabel.get_bounding_box(gc), float64)
                for ticklabel in self.ticklabel_cache
            ]
        else:
            scale = 'log' if isinstance(self.mapper, LogMapper) else 'linear'
            if self.small_haxis_style:
                # Only the two ends of the data range get a tick.
                tick_list = array([datalow, datahigh])
            else:
                tick_list = array(
                    self.tick_generator.get_ticks(datalow,
                                                  datahigh,
                                                  datalow,
                                                  datahigh,
                                                  self.tick_interval,
                                                  use_endpoints=False,
                                                  scale=scale), float64)

        # Express every tick as a fraction of the screen range, then place
        # it at that fraction along the axis vector from the origin point.
        mapped_tick_positions = \
            (array(self.mapper.map_screen(tick_list)) - screenlow) / \
            (screenhigh - screenlow)
        self._tick_positions = around(array(
            [self._axis_vector * tickpos + self._origin_point
             for tickpos in mapped_tick_positions]))
        self._tick_label_list = tick_list
        self._tick_label_positions = self._tick_positions
        return

    def _compute_labels(self, gc):
        """ Generates the labels for tick marks.

        Does nothing when the tick generator provides
        ``get_ticks_and_labels``, because the labels were then already built
        together with the ticks.  Otherwise one Label is created per cached
        tick value and its bounding box is measured with *gc*.
        """
        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            # tick labels are already computed
            return

        formatter = self.tick_label_formatter
        # Without a formatter the tick value is converted with str().
        to_text = str if formatter is None else formatter

        self.ticklabel_cache = [
            Label(text=to_text(value),
                  font=self.tick_label_font,
                  color=self.tick_label_color,
                  rotate_angle=self.tick_label_rotate_angle,
                  margin=self.tick_label_margin)
            for value in self._tick_label_list
        ]
        self._tick_label_bounding_boxes = [
            array(lbl.get_bounding_box(gc), float)
            for lbl in self.ticklabel_cache
        ]

    def _calculate_geometry(self):
        """ Computes the axis geometry from this object's own position and
        bounds.

        Sets the axis sizes, the major axis and inside vectors, the title
        angle and the origin and end points of the axis line.
        """
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos
        side = self.orientation

        if side in ('top', 'bottom'):
            self._major_axis_size, self._minor_axis_size = \
                self.bounds[0], self.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            self.title_angle = 0.0

            corner = array(self.position)
            if side == 'top':
                inward = array([0., -1.])
            else:
                # 'bottom': the axis line runs along the upper edge.
                corner = corner + array([0., self.bounds[1]])
                inward = array([0., 1.])
            self._origin_point = corner + self._major_axis * screenlow
            self._inside_vector = inward

        elif side in ('left', 'right'):
            self._major_axis_size, self._minor_axis_size = \
                self.bounds[1], self.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])

            corner = array(self.position)
            if side == 'left':
                # The axis line runs along the right-hand edge.
                corner = corner + array([self.bounds[0], 0.])
                inward = array([1., 0.])
                angle = 90.0
            else:
                inward = array([-1., 0.])
                angle = 270.0
            self._origin_point = corner + self._major_axis * screenlow
            self._inside_vector = inward
            self.title_angle = angle

        if self.ensure_ticks_bounded:
            self._origin_point -= self._inside_vector * self.tick_in

        screen_span = screenhigh - screenlow
        self._end_axis_point = \
            screen_span * self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # The axis vector scaled to unit length.
        axis_vector = self._axis_vector
        self._axis_pixel_vector = axis_vector / sqrt(
            dot(axis_vector, axis_vector))

    def _calculate_geometry_overlay(self, overlay_component=None):
        """ Computes the axis geometry from the component it is drawn over.

        Parameters
        ----------
        overlay_component :
            Component supplying the position and bounds; the axis itself is
            used when ``None``.  If the component's ``origin`` puts the
            start of this axis at the far end, the screen range is swapped.
        """
        target = self if overlay_component is None else overlay_component
        component_origin = getattr(target, "origin", 'bottom left')

        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos
        side = self.orientation
        swap_range = False

        if side in ('top', 'bottom'):
            self._major_axis_size = target.bounds[0]
            self._minor_axis_size = target.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            self.title_angle = 0.0
            if side == 'top':
                self._origin_point = array([target.x, target.y2])
                self._inside_vector = array([0.0, -1.0])
            else:
                self._origin_point = array([target.x, target.y])
                self._inside_vector = array([0.0, 1.0])
            swap_range = "right" in component_origin

        elif side in ('left', 'right'):
            self._major_axis_size = target.bounds[1]
            self._minor_axis_size = target.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if side == 'left':
                self._origin_point = array([target.x, target.y])
                self._inside_vector = array([1.0, 0.0])
                self.title_angle = 90.0
            else:
                self._origin_point = array([target.x2, target.y])
                self._inside_vector = array([-1.0, 0.0])
                self.title_angle = 270.0
            swap_range = "top" in component_origin

        if swap_range:
            screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            self._origin_point -= self._inside_vector * self.tick_in

        screen_span = screenhigh - screenlow
        self._end_axis_point = \
            screen_span * self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # The axis vector scaled to unit length.
        axis_vector = self._axis_vector
        self._axis_pixel_vector = axis_vector / sqrt(
            dot(axis_vector, axis_vector))

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _bounds_changed(self, old, new):
        """ Passes the bounds change on to the superclass, then flags that
        layout is needed and invalidates the axis.
        """
        super(PlotAxis, self)._bounds_changed(old, new)
        self._layout_needed = True
        self._invalidate()

    def _bounds_items_changed(self, event):
        """ Passes the bounds item change on to the superclass, then flags
        that layout is needed and invalidates the axis.
        """
        super(PlotAxis, self)._bounds_items_changed(event)
        self._layout_needed = True
        self._invalidate()

    def _mapper_changed(self, old, new):
        """ Moves the "updated" listener from the old mapper to the new one
        and invalidates the axis.

        Either mapper may be ``None``, in which case it is skipped.
        """
        listener = self.mapper_updated
        if old is not None:
            old.on_trait_change(listener, "updated", remove=True)
        if new is not None:
            new.on_trait_change(listener, "updated")
        self._invalidate()

    def mapper_updated(self):
        """
        Event handler that is bound to this axis's mapper's **updated** event.

        The binding is made in _mapper_changed(); the handler invalidates
        the axis.
        """
        self._invalidate()

    def _position_changed(self, old, new):
        """ Passes the position change on to the superclass and marks the
        cached geometry as stale.
        """
        super(PlotAxis, self)._position_changed(old, new)
        self._cache_valid = False

    def _position_items_changed(self, event):
        """ Passes the position item change on to the superclass and marks
        the cached geometry as stale.
        """
        super(PlotAxis, self)._position_items_changed(event)
        self._cache_valid = False

    def _position_changed_for_component(self):
        """ Marks the cached geometry as stale.

        NOTE(review): presumably called by Traits when the position of the
        attached component changes (inferred from the method name) --
        confirm.
        """
        self._cache_valid = False

    def _position_items_changed_for_component(self):
        """ Marks the cached geometry as stale.

        NOTE(review): presumably called by Traits when an item of the
        attached component's position changes (inferred from the method
        name) -- confirm.
        """
        self._cache_valid = False

    def _bounds_changed_for_component(self):
        """ Marks the cached geometry as stale and flags that layout is
        needed.

        NOTE(review): presumably called by Traits when the bounds of the
        attached component change (inferred from the method name) --
        confirm.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _bounds_items_changed_for_component(self):
        """ Marks the cached geometry as stale and flags that layout is
        needed.

        NOTE(review): presumably called by Traits when an item of the
        attached component's bounds changes (inferred from the method
        name) -- confirm.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _origin_changed_for_component(self):
        """ Invalidates the axis.

        NOTE(review): presumably called by Traits when the origin of the
        attached component changes (inferred from the method name) --
        confirm.
        """
        self._invalidate()

    def _updated_fired(self):
        """ Marks the cached geometry as stale when **updated** fires.

        Only the cache flag is cleared here; no redraw is requested by this
        handler itself.
        """
        self._cache_valid = False
        return

    def _invalidate(self):
        """ Marks the cache as stale and invalidates the drawing of the axis
        and, when one is attached, of its component.
        """
        self._cache_valid = False
        self.invalidate_draw()
        host = self.component
        if host:
            host.invalidate_draw()

    def _component_changed(self):
        """ Picks a mapper from the new component when none is set yet.

        A mapper that is already set is left alone.  Otherwise the first
        attribute that exists on the component, out of the candidates for
        this axis's orientation, becomes the mapper.  If the component has
        none of them, the mapper stays unset.
        """
        if self.mapper is not None:
            # If there is a mapper set, just leave it be.
            return

        # Try to pick the most appropriate mapper for our orientation
        # and what information we can glean from our component.
        if self.orientation in ("left", "right"):
            candidates = ("ymapper", "y_mapper", "value_mapper")
        else:
            # "top" and "bottom"
            candidates = ("xmapper", "x_mapper", "index_mapper")

        component = self.component
        for attr in candidates:
            if hasattr(component, attr):
                self.mapper = getattr(component, attr)
                break

    #------------------------------------------------------------------------
    # The following event handlers just invalidate our previously computed
    # Label instances and backbuffer if any of our visual attributes change.
    # TODO: refactor this stuff and the caching of contained objects (e.g. Label)
    #------------------------------------------------------------------------

    def _title_changed(self):
        """ Invalidates the drawing of the axis and, when one is attached,
        of its component.  The cache flag is left as it is.
        """
        self.invalidate_draw()
        host = self.component
        if host:
            host.invalidate_draw()

    def _anytrait_changed(self, name, old, new):
        """ Calls _invalidate() whenever one of the traits that define a
            visual attribute changes.

            *name* is compared against the trait names listed below; *old*
            and *new* are not used.
        """
        if name in (
                'title_font',
                'title_spacing',
                'title_color',
                'tick_weight',
                'tick_color',
                'tick_label_font',
                'tick_label_color',
                'tick_label_rotate_angle',
                'tick_label_alignment',
                'tick_label_margin',
                'tick_label_offset',
                'tick_label_position',
                'tick_label_formatter',
                'tick_in',
                'tick_out',
                'tick_visible',
                'tick_interval',
                'tick_generator',
                'orientation',
                'axis_line_visible',
                'axis_line_color',
                'axis_line_weight',
                'axis_line_style',
                'small_haxis_style',
                'ensure_labels_bounded',
                'ensure_ticks_bounded',
        ):
            self._invalidate()

    #------------------------------------------------------------------------
    # Persistence-related methods
    #------------------------------------------------------------------------

    def __getstate__(self):
        """ Returns the state to pickle, without the cached layout values.

        The state is obtained from the superclass and every key named in
        ``dont_pickle`` is removed from it when present.  __setstate__()
        resets the cache and marks it invalid, so the values are rebuilt
        the next time the axis is drawn.
        """
        dont_pickle = [
            '_tick_list', '_tick_positions', '_tick_label_list',
            '_tick_label_positions', '_tick_label_bounding_boxes',
            '_major_axis_size', '_minor_axis_size', '_major_axis',
            '_title_orientation', '_title_angle', '_origin_point',
            '_inside_vector', '_axis_vector', '_axis_pixel_vector',
            '_end_axis_point', '_ticklabel_cache', '_cache_valid',
            # The label cache trait is declared without a leading
            # underscore, so the '_ticklabel_cache' entry alone never
            # matched it.
            'ticklabel_cache',
        ]

        state = super(PlotAxis, self).__getstate__()
        for key in dont_pickle:
            # dict.has_key() does not exist on Python 3.  pop() with a
            # default removes the key when present and is a no-op otherwise.
            state.pop(key, None)

        return state

    def __setstate__(self, state):
        """ Restores pickled state and leaves the axis marked as uncomputed.

        After the superclass has restored *state*, the "updated" listener is
        attached to the restored mapper, the cached tick data is cleared and
        the cache flag is reset.
        """
        super(PlotAxis, self).__setstate__(state)
        # Call the handler by hand, with None as the "old" mapper, so that
        # mapper_updated gets bound to the restored mapper.
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
示例#5
0
class PlotGrid(AbstractOverlay):
    """ An overlay that represents a grid.

    A grid is a set of parallel lines, horizontal or vertical. You can use
    multiple grids with different settings for the horizontal and vertical
    lines in a plot.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The mapper (and associated range) that drive this PlotGrid.
    mapper = Instance(AbstractMapper)

    # The dataspace interval between grid lines.
    grid_interval = Trait('auto', 'auto', Float)

    # The dataspace value at which to start this grid.  If None, then
    # uses the mapper.range.low.
    data_min = Trait(None, None, Float)

    # The dataspace value at which to end this grid.  If None, then uses
    # the mapper.range.high.
    data_max = Trait(None, None, Float)

    # A callable that implements the AbstractTickGenerator Interface.
    tick_generator = Instance(AbstractTickGenerator)

    #------------------------------------------------------------------------
    # Layout traits
    #------------------------------------------------------------------------

    # The orientation of the grid lines.  "horizontal" means that the grid
    # lines are parallel to the X axis and the ticker and grid interval
    # refer to the Y axis.
    orientation = Enum('horizontal', 'vertical')

    # Draw the ticks starting at the end of the mapper range? If False, the
    # ticks are drawn starting at 0. This setting can be useful to keep the
    # grid from "flashing" as the user resizes the plot area.
    flip_axis = Bool(False)

    # Optional specification of the grid bounds in the dimension transverse
    # to the ticking/gridding dimension, i.e. along the direction specified
    # by self.orientation.  If this is specified but transverse_mapper is
    # not specified, then there is no effect.
    #
    #   None : use self.bounds or self.component.bounds (if overlay)
    #   Tuple : (low, high) extents, used for every grid line
    #   Callable : Function that takes an array of dataspace grid ticks
    #              and returns either an array of shape (N,2) of (starts,ends)
    #              for each grid point or a single tuple (low, high)
    transverse_bounds = Trait(None, Tuple, Callable)

    # Mapper in the direction corresponding to self.orientation, i.e. transverse
    # to the direction of self.mapper.  This is used to compute the screen
    # position of transverse_bounds.  If this is not specified, then
    # transverse_bounds has no effect, and vice versa.
    transverse_mapper = Instance(AbstractMapper)

    # Dimensions that the grid is resizable in (overrides PlotComponent).
    resizable = "hv"

    #------------------------------------------------------------------------
    # Appearance traits
    #------------------------------------------------------------------------

    # The color of the grid lines.
    line_color = black_color_trait

    # The style (i.e., dash pattern) of the grid lines.
    line_style = LineStyle('solid')

    # The thickness, in pixels, of the grid lines.
    line_width = CInt(1)
    line_weight = Alias("line_width")

    # Default Traits UI View for modifying grid attributes.
    traits_view = GridView

    #------------------------------------------------------------------------
    # Private traits; mostly cached information
    #------------------------------------------------------------------------

    # True when _tick_positions and _tick_extents are up to date.
    _cache_valid = Bool(False)
    _tick_list = Any
    # Array of (x, y) screen positions, one row per grid line.
    _tick_positions = Any

    # An array (N,2) of start,end positions in the transverse direction
    # i.e. the direction corresponding to self.orientation
    _tick_extents = Any

    #_length = Float(0.0)

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, **traits):
        """ Creates the grid with a default tick generator and a transparent
        background.

        The keyword arguments are passed on to the superclass constructor.
        The tick generator is assigned before that call and the background
        color after it.
        """
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()
        super(PlotGrid, self).__init__(**traits)
        self.bgcolor = "none"  # make sure we're transparent
        return

    @on_trait_change("bounds,bounds_items,position,position_items")
    def invalidate(self):
        """ Invalidate cached information about the grid.
        """
        self._reset_cache()
        return

    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def do_layout(self, *args, **kw):
        """ Tells this component to do layout at a given size.

        Overrides PlotComponent.
        """
        if self.use_draw_order and self.component is not None:
            self._layout_as_overlay(*args, **kw)
        else:
            super(PlotGrid, self).do_layout(*args, **kw)
        return

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _do_layout(self):
        """ Performs a layout.

        Overrides PlotComponent.
        """
        return

    def _layout_as_overlay(self, size=None, force=False):
        """ Lays out the grid as an overlay on another component.

        Copies the position and bounds of self.component, if there is one.
        The *size* and *force* arguments are accepted but not used.
        """
        if self.component is not None:
            self.position = self.component.position
            self.bounds = self.component.bounds
        return

    def _reset_cache(self):
        """ Clears the cached tick positions.
        """
        self._tick_positions = array([], dtype=float)
        self._tick_extents = array([], dtype=float)
        self._cache_valid = False
        return

    def _compute_ticks(self, component=None):
        """ Calculates the positions for the grid lines.

        Parameters
        ----------
        component : optional
            Component whose position and bounds define the grid geometry.
            If None, self.component is used; if that is None as well, the
            grid's own position and bounds are used.

        Fills in _tick_positions and _tick_extents and marks the cache as
        valid.  If there is no mapper, or the data range is empty, infinite
        or maps to a single screen value, the cache is left empty (but
        valid), so nothing is drawn.
        """
        if (self.mapper is None):
            self._reset_cache()
            self._cache_valid = True
            return

        if self.data_min is None:
            datalow = self.mapper.range.low
        else:
            datalow = self.data_min
        if self.data_max is None:
            datahigh = self.mapper.range.high
        else:
            datahigh = self.data_max

        # Map the low and high data points; only used to detect a range
        # that collapses to a single screen value.
        screen_at_low = self.mapper.map_screen(datalow)
        screen_at_high = self.mapper.map_screen(datahigh)

        if (datalow == datahigh) or (screen_at_high == screen_at_low) or \
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        if component is None:
            component = self.component

        if component is not None:
            bounds = component.bounds
            position = component.position
        else:
            bounds = self.bounds
            position = self.position

        if isinstance(self.mapper, LogMapper):
            scale = 'log'
        else:
            scale = 'linear'

        ticks = self.tick_generator.get_ticks(datalow,
                                              datahigh,
                                              datalow,
                                              datahigh,
                                              self.grid_interval,
                                              use_endpoints=False,
                                              scale=scale)
        tick_positions = self.mapper.map_screen(array(ticks, float64))

        # Each row of _tick_positions is an (x, y) pair: the mapped tick value
        # along the ticking direction, the component origin along the other.
        if self.orientation == 'horizontal':
            self._tick_positions = around(
                column_stack((zeros_like(tick_positions) + position[0],
                              tick_positions)))
        elif self.orientation == 'vertical':
            self._tick_positions = around(
                column_stack((tick_positions,
                              zeros_like(tick_positions) + position[1])))
        else:
            # The builtin exception; it is not an attribute of the instance.
            raise NotImplementedError(
                "Unsupported grid orientation: %r" % (self.orientation,))

        # Compute the transverse direction extents
        self._tick_extents = zeros((len(ticks), 2), dtype=float)
        if self.transverse_bounds is None or self.transverse_mapper is None:
            # No mapping needed, just use the extents
            if self.orientation == 'horizontal':
                extents = (position[0], position[0] + bounds[0])
            else:
                # 'vertical'; any other value was rejected above.
                extents = (position[1], position[1] + bounds[1])
            self._tick_extents[:] = extents
        elif callable(self.transverse_bounds):
            data_extents = self.transverse_bounds(ticks)
            tmapper = self.transverse_mapper
            if isinstance(data_extents, tuple):
                # A single (low, high) pair shared by every grid line.
                self._tick_extents[:] = tmapper.map_screen(
                    asarray(data_extents))
            else:
                # One (start, end) pair per grid line.
                extents = array([
                    tmapper.map_screen(data_extents[:, 0]),
                    tmapper.map_screen(data_extents[:, 1])
                ]).T
                self._tick_extents = extents
        else:
            # Already a tuple
            self._tick_extents[:] = self.transverse_mapper.map_screen(
                asarray(self.transverse_bounds))

        self._cache_valid = True

    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.
        """
        self._draw_component(gc, view_bounds, mode)
        return

    def overlay(self, other_component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        Overrides AbstractOverlay.
        """
        if not self.visible:
            return
        # The ticks are computed against other_component's geometry, so the
        # cache is flagged invalid again once the drawing is done.
        self._compute_ticks(other_component)
        self._draw_component(gc, view_bounds, mode)
        self._cache_valid = False
        return

    def _draw_component(self, gc, view_bounds=None, mode="normal"):
        """ Draws the component.

        This method is preserved for backwards compatibility. Overrides
        PlotComponent.
        """
        # What we're really trying to do with a grid is plot contour lines in
        # the space of the plot.  In a rectangular plot, these will always be
        # straight lines.
        if not self.visible:
            return

        if not self._cache_valid:
            self._compute_ticks()

        if len(self._tick_positions) == 0:
            return

        with gc:
            gc.set_line_width(self.line_weight)
            gc.set_line_dash(self.line_style_)
            gc.set_stroke_color(self.line_color_)
            gc.set_antialias(False)

            if self.component is not None:
                gc.clip_to_rect(*(self.component.position +
                                  self.component.bounds))
            else:
                gc.clip_to_rect(*(self.position + self.bounds))

            gc.begin_path()
            # Each line keeps its tick coordinate and runs between the two
            # transverse extents stored for it.
            if self.orientation == "horizontal":
                starts = self._tick_positions.copy()
                starts[:, 0] = self._tick_extents[:, 0]
                ends = self._tick_positions.copy()
                ends[:, 0] = self._tick_extents[:, 1]
            else:
                starts = self._tick_positions.copy()
                starts[:, 1] = self._tick_extents[:, 0]
                ends = self._tick_positions.copy()
                ends[:, 1] = self._tick_extents[:, 1]
            if self.flip_axis:
                starts, ends = ends, starts
            gc.line_set(starts, ends)
            gc.stroke_path()
        return

    def _mapper_changed(self, old, new):
        """ Moves the "updated" listener from the old mapper to the new one
        and invalidates the cached grid information.
        """
        if old is not None:
            old.on_trait_change(self.mapper_updated, "updated", remove=True)
        if new is not None:
            new.on_trait_change(self.mapper_updated, "updated")
        self.invalidate()
        return

    def mapper_updated(self):
        """
        Event handler that is bound to this mapper's **updated** event.
        """
        self.invalidate()
        return

    def _position_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _position_items_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _bounds_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _bounds_items_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    #------------------------------------------------------------------------
    # Event handlers for visual attributes.  These mostly just call request_redraw()
    #------------------------------------------------------------------------

    @on_trait_change("visible,line_color,line_style,line_weight")
    def visual_attr_changed(self):
        """ Called when an attribute that affects the appearance of the grid
        is changed.
        """
        if self.component:
            self.component.invalidate_draw()
            self.component.request_redraw()
        else:
            self.invalidate_draw()
            self.request_redraw()

    def _grid_interval_changed(self):
        """ Invalidates the cached ticks and requests a redraw. """
        self.invalidate()
        self.visual_attr_changed()

    def _orientation_changed(self):
        """ Invalidates the cached ticks and requests a redraw. """
        self.invalidate()
        self.visual_attr_changed()
        return

    ### Persistence ###########################################################
    #_pickles = ("orientation", "line_color", "line_style", "line_weight",
    #            "grid_interval", "mapper")

    def __getstate__(self):
        """ Returns the state to pickle, without the cached tick data. """
        state = super(PlotGrid, self).__getstate__()
        for key in ('_cache_valid', '_tick_list', '_tick_positions',
                    '_tick_extents'):
            # dict.has_key() does not exist on Python 3; pop() with a default
            # removes the key when present and is a no-op otherwise.
            state.pop(key, None)

        return state

    def _post_load(self):
        """ Re-attaches the mapper listener and resets the cache after the
        superclass's _post_load() has run.
        """
        super(PlotGrid, self)._post_load()
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
示例#6
0
 def __init__(self, **traits):
     """ Creates the grid with a default tick generator and a transparent
     background.

     The keyword arguments are passed on to the superclass constructor.
     """
     # TODO: change this back to a factory in the instance trait some day
     default_ticker = DefaultTickGenerator()
     self.tick_generator = default_ticker

     super(PlotGrid, self).__init__(**traits)

     # Assigned after the superclass constructor so that the grid always
     # ends up transparent.
     self.bgcolor = "none"
示例#7
0
class PlotGrid(AbstractOverlay):
    """ An overlay that represents a grid.

    A grid is a set of parallel lines, horizontal or vertical. You can use
    multiple grids with different settings for the horizontal and vertical
    lines in a plot.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The mapper (and associated range) that drive this PlotGrid.
    mapper = Instance(AbstractMapper)

    # The dataspace interval between grid lines.
    grid_interval = Trait('auto', 'auto', Float)

    # The dataspace value at which to start this grid.  If None, then
    # uses the mapper.range.low.
    data_min = Trait(None, None, Float)

    # The dataspace value at which to end this grid.  If None, then uses
    # the mapper.range.high.
    data_max = Trait(None, None, Float)

    # A callable that implements the AbstractTickGenerator Interface.
    tick_generator = Instance(AbstractTickGenerator)

    #------------------------------------------------------------------------
    # Layout traits
    #------------------------------------------------------------------------

    # The orientation of the grid lines.  "horizontal" means that the grid
    # lines are parallel to the X axis and the ticker and grid interval
    # refer to the Y axis.
    orientation = Enum('horizontal', 'vertical')

    # Draw the ticks starting at the end of the mapper range? If False, the
    # ticks are drawn starting at 0. This setting can be useful to keep the
    # grid from "flashing" as the user resizes the plot area.
    flip_axis = Bool(False)

    # Optional specification of the grid bounds in the dimension transverse
    # to the ticking/gridding dimension, i.e. along the direction specified
    # by self.orientation.  If this is specified but transverse_mapper is
    # not specified, then there is no effect.
    #
    #   None : use self.bounds or self.component.bounds (if overlay)
    #   Tuple : (low, high) extents, used for every grid line
    #   Callable : Function that takes an array of dataspace grid ticks
    #              and returns either an array of shape (N,2) of (starts,ends)
    #              for each grid point or a single tuple (low, high)
    transverse_bounds = Trait(None, Tuple, Callable)

    # Mapper in the direction corresponding to self.orientation, i.e. transverse
    # to the direction of self.mapper.  This is used to compute the screen
    # position of transverse_bounds.  If this is not specified, then
    # transverse_bounds has no effect, and vice versa.
    transverse_mapper = Instance(AbstractMapper)

    # Dimensions that the grid is resizable in (overrides PlotComponent).
    resizable = "hv"

    #------------------------------------------------------------------------
    # Appearance traits
    #------------------------------------------------------------------------

    # The color of the grid lines.
    line_color = black_color_trait

    # The style (i.e., dash pattern) of the grid lines.
    line_style = LineStyle('solid')

    # The thickness, in pixels, of the grid lines.
    line_width = CInt(1)
    line_weight = Alias("line_width")

    # Default Traits UI View for modifying grid attributes.
    traits_view = GridView

    #------------------------------------------------------------------------
    # Private traits; mostly cached information
    #------------------------------------------------------------------------

    # True when _tick_positions and _tick_extents are up to date.
    _cache_valid = Bool(False)
    _tick_list = Any
    # Array of (x, y) screen positions, one row per grid line.
    _tick_positions = Any

    # An array (N,2) of start,end positions in the transverse direction
    # i.e. the direction corresponding to self.orientation
    _tick_extents = Any

    #_length = Float(0.0)


    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, **traits):
        """ Creates the grid with a default tick generator and a transparent
        background.

        The keyword arguments are passed on to the superclass constructor.
        The tick generator is assigned before that call and the background
        color after it.
        """
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()
        super(PlotGrid, self).__init__(**traits)
        self.bgcolor = "none"  # make sure we're transparent
        return

    @on_trait_change("bounds,bounds_items,position,position_items")
    def invalidate(self):
        """ Invalidate cached information about the grid.
        """
        self._reset_cache()
        return


    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def do_layout(self, *args, **kw):
        """ Tells this component to do layout at a given size.

        Overrides PlotComponent.
        """
        if self.use_draw_order and self.component is not None:
            self._layout_as_overlay(*args, **kw)
        else:
            super(PlotGrid, self).do_layout(*args, **kw)
        return

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _do_layout(self):
        """ Performs a layout.

        Overrides PlotComponent.
        """
        return

    def _layout_as_overlay(self, size=None, force=False):
        """ Lays out the grid as an overlay on another component.

        Copies the position and bounds of self.component, if there is one.
        The *size* and *force* arguments are accepted but not used.
        """
        if self.component is not None:
            self.position = self.component.position
            self.bounds = self.component.bounds
        return

    def _reset_cache(self):
        """ Clears the cached tick positions.
        """
        self._tick_positions = array([], dtype=float)
        self._tick_extents = array([], dtype=float)
        self._cache_valid = False
        return

    def _compute_ticks(self, component=None):
        """ Calculates the positions for the grid lines.

        Parameters
        ----------
        component : optional
            Component whose position and bounds define the grid geometry.
            If None, self.component is used; if that is None as well, the
            grid's own position and bounds are used.

        Fills in _tick_positions and _tick_extents and marks the cache as
        valid.  If there is no mapper, or the data range is empty, infinite
        or maps to a single screen value, the cache is left empty (but
        valid), so nothing is drawn.
        """
        if (self.mapper is None):
            self._reset_cache()
            self._cache_valid = True
            return

        if self.data_min is None:
            datalow = self.mapper.range.low
        else:
            datalow = self.data_min
        if self.data_max is None:
            datahigh = self.mapper.range.high
        else:
            datahigh = self.data_max

        # Map the low and high data points; only used to detect a range
        # that collapses to a single screen value.
        screen_at_low = self.mapper.map_screen(datalow)
        screen_at_high = self.mapper.map_screen(datahigh)

        if (datalow == datahigh) or (screen_at_high == screen_at_low) or \
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        if component is None:
            component = self.component

        if component is not None:
            bounds = component.bounds
            position = component.position
        else:
            bounds = self.bounds
            position = self.position

        if isinstance(self.mapper, LogMapper):
            scale = 'log'
        else:
            scale = 'linear'

        ticks = self.tick_generator.get_ticks(datalow, datahigh, datalow, datahigh,
                                              self.grid_interval, use_endpoints=False,
                                              scale=scale)
        tick_positions = self.mapper.map_screen(array(ticks, float64))

        # Each row of _tick_positions is an (x, y) pair: the mapped tick value
        # along the ticking direction, the component origin along the other.
        if self.orientation == 'horizontal':
            self._tick_positions = around(column_stack((zeros_like(tick_positions) + position[0],
                                           tick_positions)))
        elif self.orientation == 'vertical':
            self._tick_positions = around(column_stack((tick_positions,
                                           zeros_like(tick_positions) + position[1])))
        else:
            # The builtin exception; it is not an attribute of the instance.
            raise NotImplementedError(
                "Unsupported grid orientation: %r" % (self.orientation,))

        # Compute the transverse direction extents
        self._tick_extents = zeros((len(ticks), 2), dtype=float)
        if self.transverse_bounds is None or self.transverse_mapper is None:
            # No mapping needed, just use the extents
            if self.orientation == 'horizontal':
                extents = (position[0], position[0] + bounds[0])
            else:
                # 'vertical'; any other value was rejected above.
                extents = (position[1], position[1] + bounds[1])
            self._tick_extents[:] = extents
        elif callable(self.transverse_bounds):
            data_extents = self.transverse_bounds(ticks)
            tmapper = self.transverse_mapper
            if isinstance(data_extents, tuple):
                # A single (low, high) pair shared by every grid line.
                self._tick_extents[:] = tmapper.map_screen(asarray(data_extents))
            else:
                # One (start, end) pair per grid line.
                extents = array([tmapper.map_screen(data_extents[:,0]),
                                 tmapper.map_screen(data_extents[:,1])]).T
                self._tick_extents = extents
        else:
            # Already a tuple
            self._tick_extents[:] = self.transverse_mapper.map_screen(asarray(self.transverse_bounds))

        self._cache_valid = True


    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.
        """
        self._draw_component(gc, view_bounds, mode)
        return

    def overlay(self, other_component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        Overrides AbstractOverlay.
        """
        if not self.visible:
            return
        # The ticks are computed against other_component's geometry, so the
        # cache is flagged invalid again once the drawing is done.
        self._compute_ticks(other_component)
        self._draw_component(gc, view_bounds, mode)
        self._cache_valid = False
        return

    def _draw_component(self, gc, view_bounds=None, mode="normal"):
        """ Draws the component.

        This method is preserved for backwards compatibility. Overrides
        PlotComponent.
        """
        # What we're really trying to do with a grid is plot contour lines in
        # the space of the plot.  In a rectangular plot, these will always be
        # straight lines.
        if not self.visible:
            return

        if not self._cache_valid:
            self._compute_ticks()

        if len(self._tick_positions) == 0:
            return

        with gc:
            gc.set_line_width(self.line_weight)
            gc.set_line_dash(self.line_style_)
            gc.set_stroke_color(self.line_color_)
            gc.set_antialias(False)

            if self.component is not None:
                gc.clip_to_rect(*(self.component.position + self.component.bounds))
            else:
                gc.clip_to_rect(*(self.position + self.bounds))

            gc.begin_path()
            # Each line keeps its tick coordinate and runs between the two
            # transverse extents stored for it.
            if self.orientation == "horizontal":
                starts = self._tick_positions.copy()
                starts[:,0] = self._tick_extents[:,0]
                ends = self._tick_positions.copy()
                ends[:,0] = self._tick_extents[:,1]
            else:
                starts = self._tick_positions.copy()
                starts[:,1] = self._tick_extents[:,0]
                ends = self._tick_positions.copy()
                ends[:,1] = self._tick_extents[:,1]
            if self.flip_axis:
                starts, ends = ends, starts
            gc.line_set(starts, ends)
            gc.stroke_path()
        return

    def _mapper_changed(self, old, new):
        """ Moves the "updated" listener from the old mapper to the new one
        and invalidates the cached grid information.
        """
        if old is not None:
            old.on_trait_change(self.mapper_updated, "updated", remove=True)
        if new is not None:
            new.on_trait_change(self.mapper_updated, "updated")
        self.invalidate()
        return

    def mapper_updated(self):
        """
        Event handler that is bound to this mapper's **updated** event.
        """
        self.invalidate()
        return

    def _position_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _position_items_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _bounds_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    def _bounds_items_changed_for_component(self):
        """ Invalidates the cached grid information. """
        self.invalidate()

    #------------------------------------------------------------------------
    # Event handlers for visual attributes.  These mostly just call request_redraw()
    #------------------------------------------------------------------------

    @on_trait_change("visible,line_color,line_style,line_weight")
    def visual_attr_changed(self):
        """ Called when an attribute that affects the appearance of the grid
        is changed.
        """
        if self.component:
            self.component.invalidate_draw()
            self.component.request_redraw()
        else:
            self.invalidate_draw()
            self.request_redraw()

    def _grid_interval_changed(self):
        """ Invalidates the cached ticks and requests a redraw. """
        self.invalidate()
        self.visual_attr_changed()

    def _orientation_changed(self):
        """ Invalidates the cached ticks and requests a redraw. """
        self.invalidate()
        self.visual_attr_changed()
        return


    ### Persistence ###########################################################
    #_pickles = ("orientation", "line_color", "line_style", "line_weight",
    #            "grid_interval", "mapper")

    def __getstate__(self):
        """ Returns the state to pickle, without the cached tick data. """
        state = super(PlotGrid,self).__getstate__()
        for key in ('_cache_valid', '_tick_list', '_tick_positions', '_tick_extents'):
            # dict.has_key() does not exist on Python 3; pop() with a default
            # removes the key when present and is a no-op otherwise.
            state.pop(key, None)

        return state

    def _post_load(self):
        """ Re-attaches the mapper listener and resets the cache after the
        superclass's _post_load() has run.
        """
        super(PlotGrid, self)._post_load()
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
示例#8
0
 def __init__(self, **traits):
     """ Creates the grid with a default tick generator and a transparent
     background.

     The keyword arguments are passed on to the superclass constructor.
     """
     # TODO: change this back to a factory in the instance trait some day
     default_ticker = DefaultTickGenerator()
     self.tick_generator = default_ticker

     super(PlotGrid, self).__init__(**traits)

     # Assigned after the superclass constructor so that the grid always
     # ends up transparent.
     self.bgcolor = "none"