Exemplo n.º 1
0
class Stroke(HasStrictTraits):
    """A single stroke: an outline path plus its median points."""

    #: Path vertices; shape => (# of points) x 2
    vertices = ArrayOrNone(shape=(None, 2), dtype=np.float64)

    #: Codes for path vertices
    codes = List(String)

    #: Median vertices; shape => # of points x 2
    medians = ArrayOrNone(shape=(None, 2), dtype=np.float64)

    @classmethod
    def from_hanzi_data(cls, path_data, median_data, median_offset=0):
        """Alternate constructor from raw path and median data.

        ``path_data`` is handed to ``svg_parse`` to obtain the vertices and
        codes.  The first entry of ``median_data`` is replaced by the result
        of ``adjust_point_by(median_data[:2], 0, -median_offset)``.

        Note that ``median_data`` is modified in place.
        """
        parsed_vertices, parsed_codes = svg_parse(path_data)
        shifted_start = adjust_point_by(median_data[:2], 0, -median_offset)
        median_data[0] = shifted_start
        return cls(
            vertices=parsed_vertices,
            codes=parsed_codes,
            medians=median_data,
        )

    def to_mpl_path(self):
        """Return a ``Path`` built from the vertices and the expanded codes.

        Every single-letter code in ``codes`` expands to one or more ``Path``
        codes (``Q`` to two ``CURVE3``, ``C`` to three ``CURVE4``).  A code
        outside M/L/Q/C/Z raises ``KeyError``.
        """
        expansion = {
            'M': (Path.MOVETO, ),
            'L': (Path.LINETO, ),
            'Q': (Path.CURVE3, ) * 2,
            'C': (Path.CURVE4, ) * 3,
            'Z': (Path.CLOSEPOLY, )
        }
        mpl_codes = [
            mpl_code
            for svg_code in self.codes
            for mpl_code in expansion[svg_code]
        ]
        return Path(self.vertices, mpl_codes)
Exemplo n.º 2
0
    class Foo(HasTraits):
        # Sample class declaring ArrayOrNone traits with different options.

        # The trait type used directly, without calling it.
        maybe_array = ArrayOrNone

        # Declared with dtype=float.
        maybe_float_array = ArrayOrNone(dtype=float)

        # Declared with a two-element shape whose dimensions are both None.
        maybe_two_d_array = ArrayOrNone(shape=(None, None))

        # Declared with an explicit default value of [1, 2, 3].
        maybe_array_with_default = ArrayOrNone(value=[1, 2, 3])

        # Declared with comparison_mode=NO_COMPARE.
        maybe_array_no_compare = ArrayOrNone(comparison_mode=NO_COMPARE)
Exemplo n.º 3
0
class FiducialsSource(HasTraits):
    """Provide the points stored in a fiducials fif file.

    Parameters
    ----------
    file : File
        Path to a fif file with fiducials (*.fif).

    Attributes
    ----------
    points : Array, shape = (n_points, 3)
        Fiducials file points.  Falls back to ``mni_points`` (which may be
        None) when ``file`` does not exist.
    """

    file = File(filter=[fid_wildcard])
    fname = Property(depends_on='file')
    points = Property(ArrayOrNone, depends_on='file')
    mni_points = ArrayOrNone(float, shape=(3, 3))

    def _get_fname(self):
        """Return the base name of ``file``."""
        return op.basename(self.file)

    @cached_property
    def _get_points(self):
        """Read the fiducial coordinates from ``file``.

        On a read failure an error dialog is shown, ``file`` is reset and
        the exception is re-raised.
        """
        if op.exists(self.file):
            try:
                return _fiducial_coords(*read_fiducials(self.file))
            except Exception as err:
                message = ("Error reading fiducials from %s: %s (See terminal "
                           "for more information)" % (self.fname, str(err)))
                error(None, message, "Error Reading Fiducials")
                self.reset_traits(['file'])
                raise

        # No such file: fall back to the MNI points (can be None).
        return self.mni_points
Exemplo n.º 4
0
class Gradient(HasStrictTraits):
    """ A color gradient defined by a sequence of color stops. """

    #: The sequence of color stops for the gradient.
    stops = List(Instance(ColorStop))

    #: A trait which fires when the gradient is updated.
    updated = Event()

    #: A temporary cache for the stop array.
    _array_cache = ArrayOrNone()

    def to_array(self):
        """ Return the stops as one array, ordered by offset.

        This is the raw form of the stops required by Kiva.  The array is
        built on first use and cached until the stops change.

        Returns
        -------
        stop_array : array
            An array with one row of (offset, r, g, b, a) values per color
            stop.  This array should not be mutated.
        """
        if self._array_cache is None:
            ordered_stops = sorted(self.stops, key=attrgetter("offset"))
            rows = [stop.to_array() for stop in ordered_stops]
            self._array_cache = array(rows)
        return self._array_cache

    @observe("stops.items.updated")
    def observe_stops(self, event):
        """ Drop the cached array and fire ``updated``. """
        self._array_cache = None
        self.updated = True

    def _stops_default(self):
        """ Default gradient: white at offset 0.0 to black at offset 1.0. """
        endpoints = ((0.0, "white"), (1.0, "black"))
        return [
            ColorStop(offset=offset, color=color)
            for offset, color in endpoints
        ]
Exemplo n.º 5
0
class RectangularSelection(LassoSelection):
    """ A lasso selection tool that selects a rectangular region.
    """

    #: The first click. This represents a corner of the rectangle.
    first_corner = ArrayOrNone(shape=(2, ))

    def selecting_mouse_move(self, event):
        """ Same as the superclass handler, except that the active selection
        is the rectangle produced by `_make_rectangle`.
        """
        # Make the event's location relative to this container.
        event.push_transform(
            self.component.get_event_transform(event), caller=self
        )
        cursor = self._map_data(np.array([event.x, event.y]))

        # The first point seen during a drag anchors the rectangle.
        if self.first_corner is None:
            self.first_corner = cursor

        self._active_selection = self._make_rectangle(
            self.first_corner, cursor
        )
        self.updated = True
        if self.incremental_select:
            self._update_selection()

        # Report None for the previous selections
        self.trait_property_changed("disjoint_selections", None)

    def selecting_mouse_up(self, event):
        """ Finish the selection and forget the anchoring corner. """
        super(RectangularSelection, self).selecting_mouse_up(event)
        self.first_corner = None

    def _make_rectangle(self, p1, p2):
        """ Returns an array holding the path around the rectangle whose
        opposite corners are p1 and p2:
            *-----p2
            |     |
            p1----*
        """
        x1, y1 = p1[0], p1[1]
        x2, y2 = p2[0], p2[1]
        return np.array([p1, [x1, y2], p2, [x2, y1]])
 class Bar(HasTraits):
     # The default value [1, 2, 3] is one-dimensional while the declared
     # shape has two dimensions.
     # NOTE(review): presumably this mismatch is rejected by the trait -- confirm.
     bad_array = ArrayOrNone(shape=(None, None), value=[1, 2, 3])
 class Bar(HasTraits):
     # float32 trait declared without an explicit casting argument.
     unsafe_f32 = ArrayOrNone(dtype="float32")
     # float32 trait declared with casting="safe".
     safe_f32 = ArrayOrNone(dtype="float32", casting="safe")
        class FooBar(HasTraits):
            # Both traits are declared with the same default, test_default
            # (defined outside this snippet).
            foo = ArrayOrNone(value=test_default)

            bar = ArrayOrNone(value=test_default)
Exemplo n.º 9
0
class LabelAxis(PlotAxis):
    """ An axis whose ticks are labeled with text instead of numbers.
    """

    #: List of labels to use on tick marks.
    labels = List(Str)

    #: The angle of rotation of the label. Only multiples of 90 are supported.
    label_rotation = Float(0)

    #: List of indices of ticks
    positions = ArrayOrNone()

    def _compute_tick_positions(self, gc, component=None):
        """ Calculates the positions for the tick marks.

        Overrides PlotAxis.

        Only entries of ``positions`` that lie inside the mapper's current
        data range are candidates.  The tick generator's output is then used
        as a guide to choose which candidates are displayed.  A ``positions``
        value of None is treated as "no positions".

        Raises
        ------
        RuntimeError
            If the data range low is greater than its high (see the note at
            the check below).
        """
        if self.mapper is None:
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = self.mapper.range.low
        datahigh = self.mapper.range.high
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if (datalow == datahigh) or (screenlow == screenhigh) or \
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        if not self.tick_generator:
            return

        # Get a set of ticks from the tick generator.
        tick_list = array(
            self.tick_generator.get_ticks(datalow, datahigh, datalow, datahigh,
                                          self.tick_interval), float64)

        # ``positions`` defaults to None; iterate over nothing in that case
        # instead of failing with a TypeError.
        positions = self.positions if self.positions is not None else ()

        # Find all the positions in the current range, remembering the index
        # of each one so the matching label can be looked up later.
        pos_index = []
        pos = []
        for i, position in enumerate(positions):
            if datalow <= position <= datahigh:
                pos_index.append(i)
                pos.append(position)
        if not pos_index:
            # No positions currently visible.
            self._tick_positions = []
            self._tick_label_positions = []
            self._tick_label_list = []
            return

        # Use the ticks generated by the tick generator as a guide for selecting
        # the positions to be displayed.
        tick_indices = unique(searchsorted(pos, tick_list))
        tick_indices = tick_indices[tick_indices < len(pos)]
        tick_positions = take(pos, tick_indices)
        self._tick_label_list = take(self.labels, take(pos_index,
                                                       tick_indices))

        # NOTE: when datalow > datahigh no position can pass the range test
        # above, so the method has already returned and this check does not
        # fire in practice.  It is kept as a safeguard.
        if datalow > datahigh:
            raise RuntimeError(
                "DataRange low is greater than high; unable to compute axis ticks."
            )

        # Fraction of the way along the axis for each tick, then converted
        # to screen points along the axis vector.
        screen_span = screenhigh - screenlow
        mapped_label_positions = [
            (self.mapper.map_screen(tick_pos) - screenlow) / screen_span
            for tick_pos in tick_positions
        ]
        self._tick_positions = [
            self._axis_vector * fraction + self._origin_point
            for fraction in mapped_label_positions
        ]
        self._tick_label_positions = self._tick_positions
        return

    def _compute_labels(self, gc):
        """Generates the labels for tick marks.

        Overrides PlotAxis.

        Errors raised while building the labels are printed with
        ``print_exc`` rather than propagated (best-effort drawing).
        """
        try:
            self.ticklabel_cache = []
            for text in self._tick_label_list:
                ticklabel = Label(text=text,
                                  font=self.tick_label_font,
                                  color=self.tick_label_color,
                                  rotate_angle=self.label_rotation)
                self.ticklabel_cache.append(ticklabel)

            self._tick_label_bounding_boxes = [
                array(ticklabel.get_bounding_box(gc), float64)
                for ticklabel in self.ticklabel_cache
            ]
        except Exception:
            # Deliberately best-effort, but do not swallow KeyboardInterrupt
            # or SystemExit the way a bare ``except:`` would.
            print_exc()
        return
Exemplo n.º 10
0
class MarkerPointDest(MarkerPoints):  # noqa: D401
    """MarkerPoints subclass that serves for derived points.

    The points are derived from two ``MarkerPointSource`` objects, either by
    averaging them or by estimating a transformation between them (see
    ``method``).
    """

    src1 = Instance(MarkerPointSource)
    src2 = Instance(MarkerPointSource)

    name = Property(Str, depends_on='src1.name,src2.name')
    dir = Property(Str, depends_on='src1.dir,src2.dir')

    points = Property(ArrayOrNone(float, (5, 3)),
                      depends_on=[
                          'method', 'src1.points', 'src1.use', 'src2.points',
                          'src2.use'
                      ])
    enabled = Property(Bool, depends_on=['points'])

    method = Enum('Transform',
                  'Average',
                  desc="Transform: estimate a rotation"
                  "/translation from mrk1 to mrk2; Average: use the average "
                  "of the mrk1 and mrk2 coordinates for each point.")

    view = View(
        VGroup(Item('method', style='custom'),
               Item('save_as', enabled_when='can_save', show_label=False)))

    @cached_property
    def _get_dir(self):
        """Return the directory of the first source."""
        return self.src1.dir

    @cached_property
    def _get_name(self):
        """Return a name derived from the two source names.

        If only one source has a name, that name is returned (or '' when
        neither has one).  Otherwise the result is the longest common prefix
        of the two names.
        """
        n1 = self.src1.name
        n2 = self.src2.name

        if not n1:
            if n2:
                return n2
            else:
                return ''
        elif not n2:
            return n1

        if n1 == n2:
            return n1

        # Length of the longest common prefix.  zip() stops at the shorter
        # name, so neither string is ever indexed past its end.
        prefix_len = 0
        for char1, char2 in zip(n1, n2):
            if char1 != char2:
                break
            prefix_len += 1

        return n1[:prefix_len]

    @cached_property
    def _get_enabled(self):
        """Return whether any coordinate in ``points`` is non-zero."""
        return np.any(self.points)

    @cached_property
    def _get_points(self):
        """Compute the derived points from the enabled sources.

        Returns an array of zeros with shape (5, 3) when no source is
        enabled or when the sources do not provide enough points for the
        selected method (an error dialog is shown in the latter case).
        """
        # in case only one or no source is enabled
        if not (self.src1 and self.src1.enabled):
            if (self.src2 and self.src2.enabled):
                return self.src2.points
            else:
                return np.zeros((5, 3))
        elif not (self.src2 and self.src2.enabled):
            return self.src1.points

        # Average method
        if self.method == 'Average':
            if len(np.union1d(self.src1.use, self.src2.use)) < 5:
                error(None, "Need at least one source for each point.",
                      "Marker Average Error")
                return np.zeros((5, 3))

            pts = (self.src1.points + self.src2.points) / 2.
            # Points used by only one source are taken from that source
            # instead of being averaged.
            for i in np.setdiff1d(self.src1.use, self.src2.use):
                pts[i] = self.src1.points[i]
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = self.src2.points[i]

            return pts

        # Transform method
        idx = np.intersect1d(np.array(self.src1.use),
                             np.array(self.src2.use),
                             assume_unique=True)
        if len(idx) < 3:
            error(None, "Need at least three shared points for trans"
                  "formation.", "Marker Interpolation Error")
            return np.zeros((5, 3))

        src_pts = self.src1.points[idx]
        tgt_pts = self.src2.points[idx]
        est = fit_matched_points(src_pts, tgt_pts, out='params')
        # Halve the estimated parameters: each source is moved half of the
        # estimated rotation/translation (src1 forwards, src2 backwards).
        rot = np.array(est[:3]) / 2.
        tra = np.array(est[3:]) / 2.

        if len(self.src1.use) == 5:
            trans = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans, self.src1.points)
        elif len(self.src2.use) == 5:
            trans = np.dot(translation(*-tra), rotation(*-rot))
            pts = apply_trans(trans, self.src2.points)
        else:
            # Neither source uses all five points: start from src1 and fill
            # in the points only src2 uses.
            trans1 = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans1, self.src1.points)
            trans2 = np.dot(translation(*-tra), rotation(*-rot))
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = apply_trans(trans2, self.src2.points[i])

        return pts
Exemplo n.º 11
0
class PlotAxis(AbstractOverlay):
    """
    The PlotAxis is a visual component that can be rendered on its own as
    a standalone component or attached as an overlay to another component.
    (To attach it as an overlay, set its **component** attribute.)

    When it is attached as an overlay, it draws into the padding around
    the component.
    """

    #: The mapper that drives this axis.
    mapper = Instance(AbstractMapper)

    #: Keep an origin for plots that aren't attached to a component
    origin = Enum("bottom left", "top left", "bottom right", "top right")

    #: The text of the axis title.
    title = Trait('', Str, Unicode) #May want to add PlotLabel option

    #: The font of the title.
    title_font = KivaFont('modern 12')

    #: The spacing between the axis line and the title
    title_spacing = Trait('auto', 'auto', Float)

    #: The color of the title.
    title_color = ColorTrait("black")

    #: The angle of the title, in degrees, from horizontal line
    title_angle = Float(0.)

    #: The thickness (in pixels) of each tick.
    tick_weight = Float(1.0)

    #: The color of the ticks.
    tick_color = ColorTrait("black")

    #: The font of the tick labels.
    tick_label_font = KivaFont('modern 10')

    #: The color of the tick labels.
    tick_label_color = ColorTrait("black")

    #: The rotation of the tick labels.
    tick_label_rotate_angle = Float(0)

    #: Whether to align to corners or edges (corner is better for 45 degree rotation)
    tick_label_alignment = Enum('edge', 'corner')

    #: The margin around the tick labels.
    tick_label_margin = Int(2)

    #: The distance of the tick label from the axis.
    tick_label_offset = Float(8.)

    #: Whether the tick labels appear to the inside or the outside of the plot area
    tick_label_position = Enum("outside", "inside")

    #: A callable that is passed the numerical value of each tick label and
    #: that returns a string.
    tick_label_formatter = Callable(DEFAULT_TICK_FORMATTER)

    #: The number of pixels by which the ticks extend into the plot area.
    tick_in = Int(5)

    #: The number of pixels by which the ticks extend into the label area.
    tick_out = Int(5)

    #: Are ticks visible at all?
    tick_visible = Bool(True)

    #: The dataspace interval between ticks.
    tick_interval = Trait('auto', 'auto', Float)

    #: A callable that implements the AbstractTickGenerator interface.
    tick_generator = Instance(AbstractTickGenerator)

    #: The location of the axis relative to the plot.  This determines where
    #: the axis title is located relative to the axis line.
    orientation = Enum("top", "bottom", "left", "right")

    #: Is the axis line visible?
    axis_line_visible = Bool(True)

    #: The color of the axis line.
    axis_line_color = ColorTrait("black")

    #: The line thickness (in pixels) of the axis line.
    axis_line_weight = Float(1.0)

    #: The dash style of the axis line.
    axis_line_style = LineStyle('solid')

    #: A special version of the axis line that is more useful for geophysical
    #: plots.
    small_haxis_style = Bool(False)

    #: Does the axis ensure that its end labels fall within its bounding area?
    ensure_labels_bounded = Bool(False)

    #: Does the axis prevent the ticks from being rendered outside its bounds?
    #: This flag is off by default because the standard axis *does* render ticks
    #: that encroach on the plot area.
    ensure_ticks_bounded = Bool(False)

    #: Fired when the axis's range bounds change.
    updated = Event

    #------------------------------------------------------------------------
    # Override default values of inherited traits
    #------------------------------------------------------------------------

    #: Background color (overrides AbstractOverlay). Axes usually let the color of
    #: the container show through.
    bgcolor = ColorTrait("transparent")

    #: Dimensions that the axis is resizable in (overrides PlotComponent).
    #: Typically, axes are resizable in both dimensions.
    resizable = "hv"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # Cached position calculations

    _tick_list = List  # These are caches of their respective positions
    # Screen positions of the ticks; assigned by _compute_tick_positions()
    # and emptied by _reset_cache().
    _tick_positions = ArrayOrNone()
    # Values from which the tick labels are built in _compute_labels().
    _tick_label_list = ArrayOrNone()
    # Positions at which the tick labels are drawn; _compute_tick_positions()
    # sets this to the same value as _tick_positions.
    _tick_label_positions = ArrayOrNone()
    # Bounding box of each label in ticklabel_cache, in the same order.
    _tick_label_bounding_boxes = List
    # Extent of the axis bounds along / across the axis direction; assigned
    # in _calculate_geometry().
    _major_axis_size = Float
    _minor_axis_size = Float
    # Unit vector along the axis: [1, 0] for top/bottom orientation and
    # [0, 1] for left/right orientation.
    _major_axis = Array
    _title_orientation = Array
    _title_angle = Float
    # Start point of the axis line.
    _origin_point = Array
    # Unit vector pointing from the axis line towards the plot interior.
    _inside_vector = Array
    _axis_vector = Array
    _axis_pixel_vector = Array
    # End point of the axis line.
    _end_axis_point = Array


    # Label objects for the current ticks.
    ticklabel_cache = List
    # False until the cached geometry, ticks and labels have been computed;
    # checked by _draw_component().
    _cache_valid = Bool(False)


    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, component=None, **kwargs):
        """ Creates the axis, optionally attaching it to a component.

        Parameters
        ----------
        component : optional
            If not None, assigned to ``self.component`` after every trait in
            ``kwargs`` has been set.
        **kwargs
            Trait values forwarded to the superclass constructor.
        """
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()
        # Override init so that our component gets set last.  We want the
        # _component_changed() event handler to get run last.
        super(PlotAxis, self).__init__(**kwargs)
        if component is not None:
            self.component = component

    def invalidate(self):
        """ Discards the cached tick data and requests a redraw.
        """
        self._reset_cache()
        self.invalidate_draw()

    def traits_view(self):
        """ Returns the View used by Traits UI for this axis.

        The Traits framework calls this method automatically when
        .edit_traits() is invoked.
        """
        # Imported here rather than at module level, as in the original.
        from .axis_view import AxisView as axis_view
        return axis_view


    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def _do_layout(self, *args, **kw):
        """ Lays out this component at a given size.

        Overrides Component.  When draw order is in use and the axis is
        attached to a component, it is laid out as an overlay; otherwise the
        inherited layout is used.
        """
        as_overlay = self.use_draw_order and self.component is not None
        if not as_overlay:
            super(PlotAxis, self)._do_layout(*args, **kw)
            return
        self._layout_as_overlay(*args, **kw)

    def overlay(self, component, gc, view_bounds=None, mode='normal'):
        """ Draws this axis on top of another component.

        Overrides AbstractOverlay.  Nothing is drawn while the axis is not
        visible.
        """
        if self.visible:
            self._draw_component(gc, view_bounds, mode, component)

    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer by delegating to `_draw_component`.

        Overrides PlotComponent.
        """
        self._draw_component(gc, view_bounds, mode)

    def _draw_component(self, gc, view_bounds=None, mode='normal', component=None):
        """ Draws the axis line, title, ticks and tick labels.

        This method is preserved for backwards compatibility. Overrides
        PlotComponent.

        If the cache is not valid, the geometry, tick positions and labels
        are recomputed first.  When `component` is given, the overlay
        geometry calculation is used.
        """
        if not self.visible:
            return

        if not self._cache_valid:
            if component is None:
                self._calculate_geometry()
            else:
                self._calculate_geometry_overlay(component)
            self._compute_tick_positions(gc, component)
            self._compute_labels(gc)

        with gc:
            # Setting the tick label font on the base gc up front means the
            # set_font() calls made by the title and tick labels have no
            # work left to do.
            gc.set_font(self.tick_label_font)

            if self.axis_line_visible:
                self._draw_axis_line(gc, self._origin_point, self._end_axis_point)
            if self.title:
                self._draw_title(gc)

            self._draw_ticks(gc)
            self._draw_labels(gc)

        self._cache_valid = True


    #------------------------------------------------------------------------
    # Private draw routines
    #------------------------------------------------------------------------

    def _layout_as_overlay(self, size=None, force=False):
        """ Positions and sizes the axis inside the padding of its component.

        Does nothing when the axis has no component.
        """
        host = self.component
        if host is None:
            return

        side = self.orientation
        if side in ("left", "right"):
            # Vertical axis: same vertical extent as the component, and as
            # wide as the padding on that side.
            self.y = host.y
            self.height = host.height
            if side == "left":
                self.width = host.padding_left
                self.x = host.outer_x
            elif side == "right":
                self.width = host.padding_right
                self.x = host.x2 + 1
        else:
            # Horizontal axis: same horizontal extent as the component, and
            # as tall as the padding on that side.
            self.x = host.x
            self.width = host.width
            if side == "bottom":
                self.height = host.padding_bottom
                self.y = host.outer_y
            elif side == "top":
                self.height = host.padding_top
                self.y = host.y2 + 1

    def _draw_axis_line(self, gc, startpoint, endpoint):
        """ Strokes the axis line between two (rounded) points.
        """
        line_start = around(startpoint)
        line_end = around(endpoint)
        with gc:
            gc.set_antialias(0)
            gc.set_line_width(self.axis_line_weight)
            gc.set_stroke_color(self.axis_line_color_)
            gc.set_line_dash(self.axis_line_style_)
            gc.move_to(*line_start)
            gc.line_to(*line_end)
            gc.stroke_path()


    def _draw_title(self, gc, label=None, axis_offset=None):
        """ Draws the axis title centered along the axis line.

        A Label is built from the title traits unless `label` is supplied.
        An explicit `title_spacing` overrides `axis_offset`; with 'auto'
        spacing and no `axis_offset`, the offset is derived from the cached
        tick labels (or 25 when there are none).
        """
        if label is not None:
            title_label = label
        else:
            title_label = Label(text=self.title,
                                font=self.title_font,
                                color=self.title_color,
                                rotate_angle=self.title_angle)

        # Bounding box of the label after rotation.
        tl_bounds = array(title_label.get_bounding_box(gc), float64)
        text_center_to_corner = -tl_bounds/2.0
        # Index of the screen axis along which we move away from the line.
        axis_index = self._major_axis.argmin()

        if self.title_spacing != 'auto':
            axis_offset = self.title_spacing

        if self.title_spacing and axis_offset is None:
            if self.ticklabel_cache:
                largest = max(tick_label._bounding_box[axis_index]
                              for tick_label in self.ticklabel_cache)
                axis_offset = largest * 1.3
            else:
                axis_offset = 25

        midpoint = (self._origin_point+self._end_axis_point)/2
        axis_dist = self.tick_out + tl_bounds[axis_index]/2.0 + axis_offset
        offset = midpoint - self._inside_vector * axis_dist + text_center_to_corner

        # Move to the title position, draw, then undo the translation.
        gc.translate_ctm(*offset)
        title_label.draw(gc)
        gc.translate_ctm(*(-offset))


    def _draw_ticks(self, gc):
        """ Strokes one line segment per cached tick position.

        Nothing is drawn when `tick_visible` is False.
        """
        if not self.tick_visible:
            return

        gc.set_stroke_color(self.tick_color_)
        gc.set_line_width(self.tick_weight)
        gc.set_antialias(False)
        gc.begin_path()

        # Each tick runs from tick_in pixels inside the axis line to
        # tick_out pixels outside it.
        inward = self._inside_vector*self.tick_in
        outward = self._inside_vector*self.tick_out
        for position in self._tick_positions:
            gc.move_to(*(position + inward))
            gc.line_to(*(position - outward))
        gc.stroke_path()

    def _draw_labels(self, gc):
        """ Draws every cached tick label next to its tick position.
        """
        # Index of the screen axis along which we move away from the line.
        axis_index = self._major_axis.argmin()

        if self.tick_label_position == "inside":
            inside_vector = -self._inside_vector
        else:
            inside_vector = self._inside_vector

        label_count = len(self._tick_label_positions)
        for i in range(label_count):
            # TODO: we want a more sophisticated scheme than just 2 decimals
            # all the time.
            ticklabel = self.ticklabel_cache[i]
            tl_bounds = self._tick_label_bounding_boxes[i]

            # Place the label so that a vector of length tick_label_offset,
            # extending from the tick mark, just touches the label's
            # rectangular bounding box.  This is not necessarily optimal for
            # axes that are neither horizontal nor vertical.
            base_position = self._tick_label_positions[i].copy()
            axis_dist = self.tick_label_offset + tl_bounds[axis_index]/2.0
            base_position -= inside_vector * axis_dist
            base_position -= tl_bounds/2.0

            if self.tick_label_alignment == 'corner':
                orientation = self.orientation
                if orientation in ("top", "bottom"):
                    base_position[0] += tl_bounds[0]/2.0
                elif orientation == "left":
                    base_position[1] -= tl_bounds[1]/2.0
                elif orientation == "right":
                    base_position[1] += tl_bounds[1]/2.0

            if self.ensure_labels_bounded:
                # Clamp the first and last labels to the ends of the axis.
                bound_idx = self._major_axis.argmax()
                if i == 0:
                    base_position[bound_idx] = max(
                        base_position[bound_idx],
                        self._origin_point[bound_idx])
                elif i == len(self._tick_label_positions)-1:
                    base_position[bound_idx] = min(
                        base_position[bound_idx],
                        self._end_axis_point[bound_idx] - tl_bounds[bound_idx])

            # Move to the label position, draw, then undo the translation.
            tlpos = around(base_position)
            gc.translate_ctm(*tlpos)
            ticklabel.draw(gc)
            gc.translate_ctm(*(-tlpos))


    #------------------------------------------------------------------------
    # Private methods for computing positions and layout
    #------------------------------------------------------------------------

    def _reset_cache(self):
        """ Empties the cached tick positions, labels, and label positions.
        """
        cached_traits = (
            "_tick_positions",
            "_tick_label_list",
            "_tick_label_positions",
        )
        for trait_name in cached_traits:
            # A fresh empty list is assigned to each trait.
            setattr(self, trait_name, [])

    def _compute_tick_positions(self, gc, overlay_component=None):
        """ Computes and caches the screen positions of the tick marks.

        The cache is cleared and marked valid when there is no mapper or
        when the data or screen range is degenerate or infinite.  A
        RuntimeError is raised if the data range low exceeds its high.
        """
        mapper = self.mapper
        if mapper is None:
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = mapper.range.low
        datahigh = mapper.range.high
        screenhigh = mapper.high_pos
        screenlow = mapper.low_pos

        # An overlaid axis follows the origin of the component it overlays.
        if overlay_component is None:
            origin = self.origin
        else:
            origin = getattr(overlay_component, 'origin', 'bottom left')

        # The screen range is reversed when the origin is on the far side
        # of this axis's direction.
        if self.orientation in ("top", "bottom"):
            flip_from_gc = "right" in origin
        elif self.orientation in ("left", "right"):
            flip_from_gc = "top" in origin
        if flip_from_gc:
            screenlow, screenhigh = screenhigh, screenlow

        if ((datalow == datahigh) or (screenlow == screenhigh)
                or (datalow in [inf, -inf]) or (datahigh in [inf, -inf])):
            self._reset_cache()
            self._cache_valid = True
            return

        if datalow > datahigh:
            raise RuntimeError("DataRange low is greater than high; unable to compute axis ticks.")

        generator = self.tick_generator
        if not generator:
            return

        if hasattr(generator, "get_ticks_and_labels"):
            # The generator supplies ticks and labels together.
            tmp = generator.get_ticks_and_labels(datalow, datahigh,
                                                 screenlow, screenhigh)
            if len(tmp) == 0:
                tick_list, labels = [], []
            else:
                tick_list, labels = tmp
            # The labels are built here, so _compute_labels() skips them.
            self.ticklabel_cache = [
                Label(text=lab,
                      font=self.tick_label_font,
                      color=self.tick_label_color)
                for lab in labels
            ]
            self._tick_label_bounding_boxes = [
                array(ticklabel.get_bounding_box(gc), float64)
                for ticklabel in self.ticklabel_cache
            ]
        else:
            scale = 'log' if isinstance(mapper, LogMapper) else 'linear'
            if self.small_haxis_style:
                # Only the two ends of the data range get a tick.
                tick_list = array([datalow, datahigh])
            else:
                raw_ticks = generator.get_ticks(datalow, datahigh,
                                                datalow, datahigh,
                                                self.tick_interval,
                                                use_endpoints=False,
                                                scale=scale)
                tick_list = array(raw_ticks, float64)

        # Fraction of the way along the axis for each tick, then converted
        # to rounded screen points along the axis vector.
        screen_span = screenhigh - screenlow
        mapped_tick_positions = (array(mapper.map_screen(tick_list)) - screenlow) / screen_span
        self._tick_positions = around(array([
            self._axis_vector*tickpos + self._origin_point
            for tickpos in mapped_tick_positions
        ]))
        self._tick_label_list = tick_list
        self._tick_label_positions = self._tick_positions


    def _compute_labels(self, gc):
        """ Builds the tick Label objects and their bounding boxes.

        Does nothing when the tick generator provides
        ``get_ticks_and_labels``, because the labels were already built
        while computing the tick positions.
        """
        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            return

        formatter = self.tick_label_formatter

        new_labels = []
        for val in self._tick_label_list:
            # Fall back to str() when no formatter is set.
            if formatter is not None:
                tickstring = formatter(val)
            else:
                tickstring = str(val)
            new_labels.append(Label(text=tickstring,
                                    font=self.tick_label_font,
                                    color=self.tick_label_color,
                                    rotate_angle=self.tick_label_rotate_angle,
                                    margin=self.tick_label_margin))

        self.ticklabel_cache = new_labels
        self._tick_label_bounding_boxes = [
            array(ticklabel.get_bounding_box(gc), float)
            for ticklabel in self.ticklabel_cache
        ]


    def _calculate_geometry(self):
        """Compute the axis geometry from this axis' own position and bounds.

        Updates ``_major_axis_size``, ``_minor_axis_size``, ``_major_axis``,
        ``_title_orientation``, ``_origin_point``, ``_inside_vector``,
        ``_end_axis_point``, ``_axis_vector`` and ``_axis_pixel_vector``.
        """
        origin = self.origin
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            self._major_axis_size = self.bounds[0]
            self._minor_axis_size = self.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            if self.orientation == 'top':
                self._origin_point = array(self.position)
                self._inside_vector = array([0., -1.])
            else:  # self.orientation == 'bottom'
                self._origin_point = array(self.position) + array([0., self.bounds[1]])
                self._inside_vector = array([0., 1.])
            if "right" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            self._major_axis_size = self.bounds[1]
            self._minor_axis_size = self.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array(self.position) + array([self.bounds[0], 0.])
                self._inside_vector = array([1., 0.])
            else:  # self.orientation == 'right'
                self._origin_point = array(self.position)
                self._inside_vector = array([-1., 0.])
            if "top" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            # Use a regular subtraction rather than ``-=``: when ``position``
            # holds integers the origin array is integer-typed, and NumPy
            # refuses to cast the float result back into it in place.
            self._origin_point = self._origin_point - self._inside_vector*self.tick_in

        # Only the length of the screen range is used, so the axis always
        # extends from the origin point along the positive major axis.
        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # Unit vector along the axis (the axis vector divided by its length).
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector, self._axis_vector))
        return


    def _calculate_geometry_overlay(self, overlay_component=None):
        """Compute the axis geometry relative to the component it overlays.

        Parameters
        ----------
        overlay_component : optional
            Object providing ``bounds``, ``x``, ``y``, ``x2``, ``y2`` and
            (optionally) ``origin``.  Defaults to this axis itself.

        Updates the same cached geometry attributes as
        ``_calculate_geometry``, but anchors the origin point on the edges of
        ``overlay_component`` instead of on this axis' own position.
        """
        if overlay_component is None:
            overlay_component = self
        component_origin = getattr(overlay_component, "origin", 'bottom left')

        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            self._major_axis_size = overlay_component.bounds[0]
            self._minor_axis_size = overlay_component.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            if self.orientation == 'top':
                self._origin_point = array([overlay_component.x, overlay_component.y2])
                self._inside_vector = array([0.0, -1.0])
            else:
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([0.0, 1.0])
            if "right" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            self._major_axis_size = overlay_component.bounds[1]
            self._minor_axis_size = overlay_component.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([1.0, 0.0])
            else:
                self._origin_point = array([overlay_component.x2, overlay_component.y])
                self._inside_vector = array([-1.0, 0.0])
            if "top" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            # Use a regular subtraction rather than ``-=``: integer component
            # coordinates yield an integer origin array, and NumPy refuses to
            # cast the float result back into it in place.
            self._origin_point = self._origin_point - self._inside_vector*self.tick_in

        # Only the length of the screen range is used, so the axis always
        # extends from the origin point along the positive major axis.
        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # Unit vector along the axis (the axis vector divided by its length).
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector, self._axis_vector))
        return


    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _bounds_changed(self, old, new):
        """Handle replacement of ``bounds``.

        Delegates to the base class handler, then flags that a new layout is
        needed and invalidates the cached axis state.
        """
        super(PlotAxis, self)._bounds_changed(old, new)
        self._layout_needed = True
        self._invalidate()

    def _bounds_items_changed(self, event):
        """Handle in-place modification of ``bounds`` items.

        Delegates to the base class handler, then flags that a new layout is
        needed and invalidates the cached axis state.
        """
        super(PlotAxis, self)._bounds_items_changed(event)
        self._layout_needed = True
        self._invalidate()

    def _mapper_changed(self, old, new):
        """Move the ``updated`` listener from the old mapper to the new one.

        Either argument may be None, in which case the corresponding
        unhook/hook step is skipped.  The cached axis state is invalidated in
        every case.
        """
        listener = self.mapper_updated
        if old is not None:
            old.on_trait_change(listener, "updated", remove=True)
        if new is not None:
            new.on_trait_change(listener, "updated")
        self._invalidate()

    def mapper_updated(self):
        """
        Event handler that is bound to this axis's mapper's **updated** event.

        Invalidates the cached axis state so it is recomputed.
        """
        self._invalidate()

    def _position_changed(self, old, new):
        """Handle replacement of ``position``.

        Delegates to the base class handler and marks the cached axis state
        as stale.
        """
        super(PlotAxis, self)._position_changed(old, new)
        self._cache_valid = False

    def _position_items_changed(self, event):
        """Handle in-place modification of ``position`` items.

        Delegates to the base class handler and marks the cached axis state
        as stale.
        """
        super(PlotAxis, self)._position_items_changed(event)
        self._cache_valid = False

    def _position_changed_for_component(self):
        """Mark the cached axis state as stale.

        NOTE(review): by its name this presumably fires when
        ``component.position`` is replaced -- confirm against the Traits
        handler naming rules.
        """
        self._cache_valid = False

    def _position_items_changed_for_component(self):
        """Mark the cached axis state as stale.

        NOTE(review): by its name this presumably fires when items of
        ``component.position`` are modified -- confirm.
        """
        self._cache_valid = False

    def _bounds_changed_for_component(self):
        """Mark the cached axis state as stale and request a new layout.

        NOTE(review): by its name this presumably fires when
        ``component.bounds`` is replaced -- confirm.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _bounds_items_changed_for_component(self):
        """Mark the cached axis state as stale and request a new layout.

        NOTE(review): by its name this presumably fires when items of
        ``component.bounds`` are modified -- confirm.
        """
        self._cache_valid = False
        self._layout_needed = True

    def _origin_changed_for_component(self):
        """Invalidate the cached axis state and request a redraw.

        NOTE(review): by its name this presumably fires when
        ``component.origin`` changes -- confirm.
        """
        self._invalidate()

    def _updated_fired(self):
        """Mark the cached axis state as stale when ``updated`` fires.

        Only the cache flag is cleared here; no redraw is requested by this
        handler itself.
        """
        self._cache_valid = False
        return

    def _invalidate(self):
        """Drop the cached axis state and ask for a redraw.

        The redraw is requested on this axis and, when one is attached, on
        its component as well.
        """
        self._cache_valid = False
        self.invalidate_draw()
        owner = self.component
        if owner:
            owner.invalidate_draw()

    def _component_changed(self):
        """Adopt a mapper and origin from the newly assigned component.

        An already configured mapper is left untouched (and the origin is not
        synchronised in that case).  Otherwise the first mapper attribute
        found on the component, among the candidates matching this axis'
        orientation, becomes the axis mapper, and the axis origin is copied
        from the component (defaulting to ``'bottom left'``).
        """
        if self.mapper is not None:
            # If there is a mapper set, just leave it be.
            return

        # Candidate attribute names, in order of preference, per orientation.
        vertical = ("ymapper", "y_mapper", "value_mapper")
        horizontal = ("xmapper", "x_mapper", "index_mapper")
        attrmap = {
            "left": vertical,
            "right": vertical,
            "bottom": horizontal,
            "top": horizontal,
        }

        component = self.component
        for attr in attrmap[self.orientation]:
            if hasattr(component, attr):
                self.mapper = getattr(component, attr)
                break

        # Keep our origin in sync with the component
        self.origin = getattr(component, 'origin', 'bottom left')
        return


    #------------------------------------------------------------------------
    # The following event handlers just invalidate our previously computed
    # Label instances and backbuffer if any of our visual attributes change.
    # TODO: refactor this stuff and the caching of contained objects (e.g. Label)
    #------------------------------------------------------------------------

    def _title_changed(self):
        """Request a redraw of the axis (and of its component, if any).

        Unlike ``_invalidate`` this leaves the cache flag untouched.
        """
        self.invalidate_draw()
        owner = self.component
        if owner:
            owner.invalidate_draw()

    def _anytrait_changed(self, name, old, new):
        """ Invalidate the axis whenever one of its visual traits changes.

        ``name`` is compared against the set of traits that affect
        appearance; a change to any other trait is ignored here.
        """
        visual_traits = {
            # Title appearance
            'title_font', 'title_spacing', 'title_color', 'title_angle',
            # Tick marks
            'tick_weight', 'tick_color', 'tick_in', 'tick_out',
            'tick_visible', 'tick_interval', 'tick_generator',
            # Tick labels
            'tick_label_font', 'tick_label_color', 'tick_label_rotate_angle',
            'tick_label_alignment', 'tick_label_margin', 'tick_label_offset',
            'tick_label_position', 'tick_label_formatter',
            # Placement
            'orientation', 'origin',
            # Axis line
            'axis_line_visible', 'axis_line_color', 'axis_line_weight',
            'axis_line_style',
            # Miscellaneous layout switches
            'small_haxis_style', 'ensure_labels_bounded',
            'ensure_ticks_bounded',
        }
        if name in visual_traits:
            self._invalidate()

    # ------------------------------------------------------------------------
    # Initialization-related methods
    # ------------------------------------------------------------------------

    def _title_angle_default(self):
        """Return the default title angle for the current orientation.

        90.0 for a left axis, 270.0 for a right axis and 0.0 for any other
        orientation.
        """
        angle_by_orientation = {'left': 90.0, 'right': 270.0}
        return angle_by_orientation.get(self.orientation, 0.0)

    #------------------------------------------------------------------------
    # Persistence-related methods
    #------------------------------------------------------------------------

    def __getstate__(self):
        """Return the picklable state without the cached layout data.

        The cached tick positions, labels and geometry vectors are removed
        from the state produced by the base class.  ``__setstate__`` resets
        the cache and marks it invalid afterwards.
        """
        dont_pickle = (
            '_tick_list',
            '_tick_positions',
            '_tick_label_list',
            '_tick_label_positions',
            '_tick_label_bounding_boxes',
            '_major_axis_size',
            '_minor_axis_size',
            '_major_axis',
            '_title_orientation',
            '_title_angle',
            '_origin_point',
            '_inside_vector',
            '_axis_vector',
            '_axis_pixel_vector',
            '_end_axis_point',
            '_ticklabel_cache',
            # The label cache is stored as ``ticklabel_cache`` (no leading
            # underscore) by ``_compute_labels``; the entry above never
            # matched it, so the cached Label objects leaked into the pickle.
            'ticklabel_cache',
            '_cache_valid',
        )

        state = super(PlotAxis, self).__getstate__()
        for key in dont_pickle:
            state.pop(key, None)

        return state

    def __setstate__(self, state):
        """Restore pickled state and force the cached layout to be rebuilt."""
        super(PlotAxis,self).__setstate__(state)
        # Re-attach the "updated" listener to the restored mapper (there is
        # no previous mapper to detach from).
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
Exemplo n.º 12
0
class RangeSelection(AbstractController):
    """ Selects a range along the index or value axis.

    The user right-click-drags to select a region, which stays selected until
    the user left-clicks to deselect.
    """

    #: The axis to which this tool is perpendicular.
    axis = Enum("index", "value")

    #: The selected region, expressed as a tuple in data space.  This updates
    #: and fires change-events as the user is dragging.
    selection = Property

    #: "append" when **append_key** matched the event that started the
    #: current selection, otherwise "set".
    selection_mode = Enum("set", "append")

    #: This event is fired whenever the user completes the selection, or when a
    #: finalized selection gets modified.  The value of the event is the data
    #: space range.
    selection_completed = Event

    #: The name of the metadata on the datasource that we will write
    #: self.selection to
    metadata_name = Str("selections")

    #: The name of the metadata on the datasource that receives the selection
    #: mode: either "set" or "append", depending on whether self.append_key
    #: was held down
    selection_mode_metadata_name = Str("selection_mode")

    #: The name of the metadata on the datasource that we will set to a numpy
    #: boolean array for masking the datasource's data
    mask_metadata_name = Str("selection_masks")

    #: The possible event states of this selection tool (overrides
    #: enable.Interactor).
    #:
    #: normal:
    #:     Nothing has been selected, and the user is not dragging the mouse.
    #: selecting:
    #:     The user is dragging the mouse and actively changing the
    #:     selection region; resizing of an existing selection also
    #:     uses this mode.
    #: selected:
    #:     The user has released the mouse and a selection has been
    #:     finalized.  The selection remains until the user left-clicks
    #:     or self.deselect() is called.
    #: moving:
    #:     The user is moving (not resizing) the selection range.
    event_state = Enum("normal", "selecting", "selected", "moving")

    #------------------------------------------------------------------------
    # Traits for overriding default object relationships
    #
    # By default, the RangeSelection assumes that self.component is a plot
    # and looks for the mapper and the axis_index on it.  If this is not the
    # case, then any (or all) three of these can be overridden by directly
    # assigning values to them.  To unset them and have them revert to default
    # behavior, assign "None" to them.
    #------------------------------------------------------------------------

    #: The plot associated with this tool. By default, this is just
    #: self.component.
    plot = Property

    #: The mapper associated with this tool. By default, this is the mapper
    #: on **plot** that corresponds to **axis**.
    mapper = Property

    #: The index to use for **axis**. By default, this is self.plot.orientation,
    #: but it can be overridden and set to 0 or 1.
    axis_index = Property

    #: List of listeners that listen to selection events.
    listeners = List

    #------------------------------------------------------------------------
    # Configuring interaction control
    #------------------------------------------------------------------------

    #: Can the user resize the selection once it has been drawn?
    enable_resize = Bool(True)

    #: The pixel distance between the mouse event and a selection endpoint at
    #: which the user action will be construed as a resize operation.
    resize_margin = Int(7)

    #: Allow the left button to begin a selection?
    left_button_selects = Bool(False)

    #: Disable all left-mouse button interactions?
    disable_left_mouse = Bool(False)

    #: Allow the tool to be put into the deselected state via mouse clicks
    allow_deselection = Bool(True)

    #: The minimum span, in pixels, of a selection region.  Any attempt to
    #: select a region smaller than this will be treated as a deselection.
    minimum_selection = Int(5)

    #: The key which, if held down while the mouse is being dragged, will
    #: indicate that the selection should be appended to an existing selection
    #: as opposed to overwriting it.
    append_key = Instance(KeySpec, args=(None, "control"))

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The value of the override plot to use, if any.  If None, then uses
    # self.component.
    _plot = Trait(None, Any)

    # The value of the override mapper to use, if any.  If None, then uses the
    # mapper on self.component.
    _mapper = Trait(None, Any)

    # Shadow trait for the **axis_index** property.
    _axis_index = Trait(None, None, Int)

    # The data space start and end coordinates of the selected region,
    # expressed as an array.
    _selection = ArrayOrNone()

    # The selection in mask form.
    _selection_mask = Array

    # The end of the selection that is being actively modified by the mouse.
    _drag_edge = Enum("high", "low")

    #------------------------------------------------------------------------
    # These record the mouse position when the user is moving (not resizing)
    # the selection
    #------------------------------------------------------------------------

    # The position of the initial user click for moving the selection.
    _down_point = Array  # (x,y)

    # The data space coordinates of **_down_point**.
    _down_data_coord = Float

    # The original selection when the mouse went down to move the selection.
    _original_selection = Any

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def deselect(self, event=None):
        """ Clears the current selection and returns the tool to 'normal'.

        The selection and the ``selection_completed`` event are both set to
        None and a redraw of the component is requested.  When the triggering
        ``event`` is supplied, the pointer is reset to an arrow and the event
        is marked as handled.
        """
        self.selection = None
        self.selection_completed = None
        self.event_state = "normal"
        self.component.request_redraw()
        if not event:
            return
        event.window.set_pointer("arrow")
        event.handled = True

    #------------------------------------------------------------------------
    # Event handlers for the "selected" event state
    #------------------------------------------------------------------------

    def selected_left_down(self, event):
        """ Handles a left-button press while the tool is in the 'selected'
        state.

        * Ignored entirely when left-mouse interaction is disabled.
        * With no selection on screen, the tool is deselected.
        * With resizing enabled and the press within the resize margin of
          either endpoint, handling is delegated to ``selected_right_down``.
        * A press inside the selection switches the tool to 'moving'.
        * Any other press deselects; when ``allow_deselection`` is off the
          press is additionally forwarded to ``normal_left_down``.
        """
        if self.disable_left_mouse:
            return

        bounds = self._get_selection_screencoords()
        if bounds is None:
            self.deselect(event)
            return

        lower, upper = min(bounds), max(bounds)
        ndx = self.axis_index
        pointer = (event.x, event.y)[ndx]

        if self.enable_resize:
            margin = self.resize_margin
            if abs(pointer - upper) <= margin or abs(pointer - lower) <= margin:
                return self.selected_right_down(event)

        if lower <= pointer <= upper:
            # Remember where the drag started, in both screen and data space.
            self.event_state = "moving"
            self._down_point = array([event.x, event.y])
            self._down_data_coord = self.mapper.map_data(self._down_point)[ndx]
            self._original_selection = array(self.selection)
        else:
            self.deselect(event)
            if not self.allow_deselection:
                # Treat this as a combination deselect + left down
                self.normal_left_down(event)
        event.handled = True

    def selected_right_down(self, event):
        """ Handles a right-button press while the tool is in the 'selected'
        state.

        With resizing enabled, a press within the resize margin of an
        endpoint switches the tool to 'selecting' with that endpoint as the
        dragged edge.  A press elsewhere, or any press when resizing is
        disabled, deselects and starts a new selection.  If resizing is
        enabled but no selection screen coordinates are available, the event
        is only marked as handled.
        """
        if not self.enable_resize:
            # Treat this as a combination deselect + right down
            self.deselect(event)
            self.normal_right_down(event)
        else:
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                pointer = (event.x, event.y)[self.axis_index]
                margin = self.resize_margin

                # Decide which endpoint (if any) was grabbed; the end point
                # takes precedence when both are within the margin.
                if abs(pointer - end) <= margin:
                    grabbed = "high"
                elif abs(pointer - start) <= margin:
                    grabbed = "low"
                else:
                    grabbed = None

                if grabbed is None:
                    # Treat this as a combination deselect + right down
                    self.deselect(event)
                    self.normal_right_down(event)
                else:
                    self.event_state = "selecting"
                    self._drag_edge = grabbed
                    self.selecting_mouse_move(event)
        event.handled = True

    def selected_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'selected' state.

        When resizing is enabled and the pointer is within the resize margin
        of either selection endpoint, the sizing cursor is shown (the event
        is not marked as handled in that case).  Otherwise the pointer is set
        to an arrow and the event is marked as handled.
        """
        near_endpoint = False
        if self.enable_resize:
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                pointer = (event.x, event.y)[self.axis_index]
                near_endpoint = (
                    abs(pointer - end) <= self.resize_margin
                    or abs(pointer - start) <= self.resize_margin
                )

        if near_endpoint:
            self._set_sizing_cursor(event)
            return

        event.window.set_pointer("arrow")
        event.handled = True

    def selected_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selected' state.

        Sets the cursor to an arrow.  The event is not marked as handled.
        """
        event.window.set_pointer("arrow")
        return

    #------------------------------------------------------------------------
    # Event handlers for the "moving" event state
    #------------------------------------------------------------------------

    def moving_left_up(self, event):
        """ Handles the left mouse button being released while the tool is in
        the 'moving' state.

        Unless left-mouse interaction is disabled, the tool returns to the
        'selected' state, ``selection_completed`` fires with the current
        selection, and the recorded press position is cleared.
        """
        if not self.disable_left_mouse:
            self.event_state = "selected"
            self.selection_completed = self.selection
            self._down_point = []
            event.handled = True

    def moving_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'moving' state.

        Shifts the original selection by the data-space distance the mouse
        has travelled since the button went down.  If the shifted selection
        would extend past the low or high end of the mapper's range, it is
        pinned to that end while keeping its original width.
        """
        pointer = array([event.x, event.y])
        pointer_data = self.mapper.map_data(pointer)[self.axis_index]
        start_selection = self._original_selection
        shifted = start_selection + (pointer_data - self._down_data_coord)
        width = start_selection[1] - start_selection[0]

        data_range = self.mapper.range
        if min(shifted) < data_range.low:
            shifted = (data_range.low, data_range.low + width)
        elif max(shifted) > data_range.high:
            shifted = (data_range.high - width, data_range.high)

        self.selection = shifted
        self.selection_completed = shifted
        self.component.request_redraw()
        event.handled = True

    def moving_mouse_leave(self, event):
        """ Handles the mouse leaving the plot while the tool is in the
        'moving' state.

        Nothing happens if the pointer is still within the plot's extent
        along the selection axis.  Otherwise the event is processed as a
        regular move, which pushes the selection to the minimum or maximum.
        """
        ndx = self.axis_index
        first_pixel = self.plot.position[ndx]
        last_pixel = first_pixel + self.plot.bounds[ndx] - 1

        pointer = self._get_axis_coord(event)
        if not (first_pixel <= pointer <= last_pixel):
            # Outside the mapping range: slam the selection to the edge.
            self.moving_mouse_move(event)

    def moving_mouse_enter(self, event):
        """ Handles the mouse re-entering the plot in the 'moving' state.

        If the left button is no longer held down, the event is treated as a
        left-button release.
        """
        if event.left_down:
            return None
        return self.moving_left_up(event)

    #------------------------------------------------------------------------
    # Event handlers for the "normal" event state
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        """ Handles a left-button press while the tool is in the 'normal'
        state.

        Starts a selection (via ``normal_right_down``) only when
        ``left_button_selects`` is enabled; otherwise does nothing.
        """
        return self.normal_right_down(event) if self.left_button_selects else None

    def normal_right_down(self, event):
        """ Handles a right-button press while the tool is in the 'normal'
        state.

        Starts a zero-width selection at the pressed position, shows the
        sizing cursor, records the press position, enters the 'selecting'
        state and picks the selection mode ("append" when ``append_key``
        matches the event, "set" otherwise).  The press is then processed as
        a first mouse move.
        """
        anchor = self.mapper.map_data(self._get_axis_coord(event))
        self.selection = (anchor, anchor)
        self._set_sizing_cursor(event)
        self._down_point = array([event.x, event.y])
        self.event_state = "selecting"

        key = self.append_key
        appending = key is not None and key.match(event)
        self.selection_mode = "append" if appending else "set"

        self.selecting_mouse_move(event)

    #------------------------------------------------------------------------
    # Event handlers for the "selecting" event state
    #------------------------------------------------------------------------

    def selecting_mouse_move(self, event):
        """ Handles the mouse being moved when the tool is in the 'selecting'
        state.

        Does nothing when there is no selection.  Otherwise, if the pointer
        lies within the plot's extent along the selection axis, the dragged
        edge follows the pointer; when the pointer crosses the opposite edge,
        the two edges swap roles.  The event is marked as handled whenever a
        selection exists.
        """
        if self.selection is None:
            return

        ndx = self.axis_index
        first_pixel = self.plot.position[ndx]
        last_pixel = first_pixel + self.plot.bounds[ndx] - 1

        if first_pixel <= self._get_axis_coord(event) <= last_pixel:
            edge = self.mapper.map_data(self._get_axis_coord(event))
            if self._drag_edge == "high":
                fixed = self.selection[0]
                if edge >= fixed:
                    self.selection = (fixed, edge)
                else:
                    # Crossed below the fixed low end: swap roles.
                    self.selection = (edge, fixed)
                    self._drag_edge = "low"
            else:
                fixed = self.selection[1]
                if edge <= fixed:
                    self.selection = (edge, fixed)
                else:
                    # Crossed above the fixed high end: swap roles.
                    self.selection = (fixed, edge)
                    self._drag_edge = "high"

            self.component.request_redraw()

        event.handled = True

    def selecting_button_up(self, event):
        """ Finishes (or cancels) the selection when a mouse button is
        released in the 'selecting' state.

        The pointer is reset to an arrow.  If a press position was recorded
        and the distance dragged along the selection axis is smaller than
        ``minimum_selection``, the tool is deselected.  Otherwise the tool
        enters the 'selected' state and fires ``selection_completed``.
        """
        event.window.set_pointer("arrow")

        release_coord = self._get_axis_coord(event)

        too_small = False
        if len(self._down_point) != 0:
            press_coord = self._down_point[self.axis_index]
            self._down_point = []
            too_small = self.minimum_selection > abs(press_coord - release_coord)

        if too_small:
            self.deselect(event)
        else:
            self.event_state = "selected"
            # Fire the "completed" event
            self.selection_completed = self.selection
        event.handled = True

    def selecting_right_up(self, event):
        """ Handles the right mouse button coming up when the tool is in the
        'selecting' state.

        Delegates to ``selecting_button_up``, which either completes the
        selection or cancels it if it is too small.
        """
        self.selecting_button_up(event)

    def selecting_left_up(self, event):
        """ Handles the left mouse button coming up when the tool is in the
        'selecting' state.

        Ignored when left-mouse interaction is disabled; otherwise delegates
        to ``selecting_button_up``.
        """
        if not self.disable_left_mouse:
            self.selecting_button_up(event)

    def selecting_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selecting' state.

        If the pointer left at or beyond either end of the plot along the
        selection axis, the corresponding end of the selection is clipped to
        the data value at the plot boundary (which boundary depends on the
        mapper's ``sign``).  The cursor is reset to an arrow and a redraw is
        requested.
        """
        ndx = self.axis_index
        first_pixel = self.plot.position[ndx]
        last_pixel = first_pixel + self.plot.bounds[ndx] - 1

        current = self.selection
        sel_low = current[0]
        sel_high = current[1]

        pointer = self._get_axis_coord(event)
        if pointer >= last_pixel:
            # clip to the boundary appropriate for the mapper's orientation.
            sel_high = self.mapper.map_data(
                last_pixel if self.mapper.sign == 1 else first_pixel)
        elif pointer <= first_pixel:
            sel_low = self.mapper.map_data(
                first_pixel if self.mapper.sign == 1 else last_pixel)

        self.selection = (sel_low, sel_high)
        event.window.set_pointer("arrow")
        self.component.request_redraw()

    def selecting_mouse_enter(self, event):
        """ Handles the mouse entering the plot when the tool is in the
        'selecting' state.

        While either mouse button is still held, the sizing cursor is shown.
        If no button is down, the button must have been released outside the
        plot, so the event is treated as a button-up.
        """
        if event.right_down or event.left_down:
            self._set_sizing_cursor(event)
            return None
        return self.selecting_button_up(event)

    #------------------------------------------------------------------------
    # Property getter/setters
    #------------------------------------------------------------------------

    def _get_plot(self):
        """Return the override plot if one is set, else ``self.component``."""
        override = self._plot
        return override if override is not None else self.component

    def _set_plot(self, val):
        """Store ``val`` as the override plot (None restores the default)."""
        self._plot = val
        return

    def _get_mapper(self):
        """Return the override mapper if one is set.

        Otherwise look up ``<axis>_mapper`` (e.g. ``index_mapper``) on the
        plot.
        """
        override = self._mapper
        if override is None:
            return getattr(self.plot, self.axis + "_mapper")
        return override

    def _set_mapper(self, new_mapper):
        """Store ``new_mapper`` as the override mapper (None restores the default)."""
        self._mapper = new_mapper
        return

    def _get_axis_index(self):
        """Return the explicitly assigned axis index.

        When none has been assigned, the index computed by
        ``_determine_axis`` is returned instead.
        """
        override = self._axis_index
        return self._determine_axis() if override is None else override

    def _set_axis_index(self, val):
        """ Stores *val* as the explicitly assigned axis index
        (``_axis_index``).
        """
        self._axis_index = val

    def _get_selection(self):
        """ Returns the selection stored under ``metadata_name`` in the
        metadata of the plot attribute named by ``axis``.
        """
        datasource = getattr(self.plot, self.axis)
        return datasource.metadata[self.metadata_name]

    def _set_selection(self, val):
        """ Stores a new selection and publishes it.

        *val* is either None or a ``(low, high)`` pair. If the plot has an
        attribute named by ``axis`` (the datasource), its metadata receives
        the selection range, the selection mode and a boolean mask of the
        data points lying inside ``[low, high]``; ``metadata_changed`` is
        assigned after the range update and again after the mask update.
        Finally a "selection" trait change is reported and every listener
        that has a ``set_value_selection`` method is called with *val*.
        """
        previous = self._selection
        self._selection = val

        datasource = getattr(self.plot, self.axis, None)
        if datasource is not None:
            range_key = self.metadata_name

            # Publish the selection range on the datasource.
            datasource.metadata[range_key] = val
            datasource.metadata_changed = {range_key: val}

            # Drop the mask this tool appended earlier; compared by identity.
            selection_masks = datasource.metadata.setdefault(
                self.mask_metadata_name, [])
            for position, mask in enumerate(selection_masks):
                if mask is self._selection_mask:
                    del selection_masks[position]
                    break

            # Publish the selection mode on the datasource.
            datasource.metadata[self.selection_mode_metadata_name] = \
                self.selection_mode

            if val is not None:
                low, high = val
                data_pts = datasource.get_data()
                fresh_mask = (data_pts >= low) & (data_pts <= high)
                selection_masks.append(fresh_mask)
                self._selection_mask = fresh_mask
            datasource.metadata_changed = {self.mask_metadata_name: val}

        self.trait_property_changed("selection", previous, val)

        for listener in self.listeners:
            if hasattr(listener, "set_value_selection"):
                listener.set_value_selection(val)

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _get_selection_screencoords(self):
        """ Maps the current selection into screen space.

        Returns the result of passing the two selection end points through
        the mapper's map_screen(), or None when there is no selection (or
        it does not have exactly two elements).
        """
        current = self.selection
        if current is None or len(current) != 2:
            return None
        return self.mapper.map_screen(array(current))

    def _set_sizing_cursor(self, event):
        """ Sets the sizing cursor on the event's window according to the
        tool's axis index.
        """
        # Axis index 0 is a horizontal range selection (left/right arrow);
        # anything else is a vertical one (up/down arrow).
        pointer = "size left" if self.axis_index == 0 else "size top"
        event.window.set_pointer(pointer)

    def _get_axis_coord(self, event, axis="index"):
        """ Returns the event's coordinate along this tool's axis of
        interest when *axis* is "index"; for any other value of *axis* the
        coordinate along the orthogonal axis is returned.
        """
        xy = (event.x, event.y)
        position = self.axis_index if axis == "index" else 1 - self.axis_index
        return xy[position]

    def _determine_axis(self):
        """ Works out whether the coordinate along this tool's axis is the
        first (0) or second (1) element of an (x, y) tuple.

        The "index" axis of an "h"-oriented plot maps to 0 and otherwise to
        1; any other axis (i.e. "value") gets the opposite element. Called by
        the ``axis_index`` getter when no explicit index has been assigned.
        """
        on_index_axis = self.axis == "index"
        horizontal = self.plot.orientation == "h"
        if on_index_axis:
            return 0 if horizontal else 1
        return 1 if horizontal else 0

    def __mapper_changed(self):
        """ Clears the current selection (via deselect()) when the observed
        mapper changes.
        """
        self.deselect()

    def _axis_changed(self, old, new):
        """ Moves the mapper-change listener when the tool's axis changes.

        The listener on the plot's ``<old>_mapper`` trait is removed and a
        listener on ``<new>_mapper`` is attached, so that a change of the
        mapper for the axis now in use triggers the deselect handler.

        Parameters
        ----------
        old : str or None
            Previous axis name; nothing is detached when it is None.
        new : str or None
            New axis name; nothing is attached when it is None.
        """
        if old is not None:
            self.plot.on_trait_change(self.__mapper_changed,
                                      old + "_mapper",
                                      remove=True)
        if new is not None:
            # Attach to the *new* axis' mapper.  (Previously this branch
            # repeated the removal for ``old``, so the new listener was
            # never hooked up and ``old=None`` raised a TypeError.)
            self.plot.on_trait_change(self.__mapper_changed,
                                      new + "_mapper")
        return
Exemplo n.º 13
0
class DataLabel(ToolTip):
    """ A label on a point in data space.

    Optionally, an arrow is drawn to the point.  The label is rendered
    either as a plain box with a line arrow (``label_style='box'``) or as a
    "bubble" whose outline includes a wedge pointing at the data point
    (``label_style='bubble'``).
    """

    # The symbol to use if **marker** is set to "custom". This attribute must
    # be a compiled path for the given Kiva context.
    custom_symbol = Any

    # The point in data space where this label should anchor itself.
    data_point = ArrayOrNone()

    # The location of the data label relative to the data point.
    label_position = LabelPositionTrait

    # The format string that determines the label's text.  This string is
    # formatted using a dict containing the keys 'x' and 'y', corresponding to
    # data space values.
    label_format = Str("(%(x)f, %(y)f)")

    # The text to show on the label, or above the coordinates for the label, if
    # show_label_coords is True
    label_text = Str

    # Flag whether to show coordinates with the label or not.
    show_label_coords = Bool(True)

    # Does the label clip itself against the main plot area?  If not, then
    # the label draws into the padding area (where axes typically reside).
    clip_to_plot = Bool(True)

    # The center x position (average of x and x2)
    xmid = Property(Float, depends_on=['x', 'x2'])

    # The center y position (average of y and y2)
    ymid = Property(Float, depends_on=['y', 'y2'])

    # 'box' is a simple rectangular box, with an arrow that is a single line
    # with an arrowhead at the data point.
    # 'bubble' can be given rounded corners (by setting `corner_radius`), and
    # the 'arrow' is a thin triangular wedge with its point at the data point.
    # When label_style is 'bubble', the following traits are ignored:
    #    arrow_size, arrow_color, arrow_root, and arrow_max_length.
    label_style = Enum('box', 'bubble')

    #----------------------------------------------------------------------
    # Marker traits
    #----------------------------------------------------------------------

    # Mark the point on the data that this label refers to?
    marker_visible = Bool(True)

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys.
    marker = MarkerTrait

    # The pixel size of the marker (doesn't include the thickness of the
    # outline).
    marker_size = Int(4)

    # The thickness, in pixels, of the outline to draw around the marker.
    # If this is 0, no outline will be drawn.
    marker_line_width = Float(1.0)

    # The color of the inside of the marker.
    marker_color = ColorTrait("red")

    # The color of the border drawn around the marker.
    marker_line_color = ColorTrait("black")

    #----------------------------------------------------------------------
    # Arrow traits
    #----------------------------------------------------------------------

    # Draw an arrow from the label to the data point?  Only
    # used if **data_point** is not None.
    arrow_visible = Bool(True)  # FIXME: replace with some sort of ArrowStyle

    # The length of the arrowhead, in screen points (e.g., pixels).
    arrow_size = Float(10)

    # The color of the arrow.
    arrow_color = ColorTrait("black")

    # The position of the base of the arrow on the label.  If this
    # is 'auto', then the label uses **label_position**.  Otherwise, it
    # treats the label as if it were at the label position indicated by
    # this attribute.
    arrow_root = Trait("auto", "auto", "top left", "top right", "bottom left",
                       "bottom right", "top center", "bottom center",
                       "left center", "right center")

    # The minimum length of the arrow before it will be drawn.  By default,
    # the arrow will be drawn regardless of how short it is.
    arrow_min_length = Float(0)

    # The maximum length of the arrow before it will be drawn.  By default,
    # the arrow will be drawn regardless of how long it is.
    arrow_max_length = Float(inf)

    #----------------------------------------------------------------------
    # Bubble traits
    #----------------------------------------------------------------------

    # The radius (in screen coordinates) of the curved corners of the "bubble".
    corner_radius = Float(10)

    #-------------------------------------------------------------------------
    # Private traits
    #-------------------------------------------------------------------------

    # Tuple (sx, sy) of the mapped screen coordinates of **data_point**.
    _screen_coords = Any

    # Arrow object returned by draw_arrow(); reused on later draws until
    # the layout or an arrow trait changes (it is then reset to None).
    _cached_arrow = Any

    # When **arrow_root** is 'auto', this determines the location on the data
    # label from which the arrow is drawn, based on the position of the label
    # relative to its data point.
    _position_root_map = {
        "top left": "bottom right",
        "top right": "bottom left",
        "bottom left": "top right",
        "bottom right": "top left",
        "top center": "bottom center",
        "bottom center": "top center",
        "left center": "right center",
        "right center": "left center"
    }

    # Maps a root position name to the names of the attributes of this
    # object that hold the corresponding (x, y) coordinates.
    _root_positions = {
        "bottom right": ("x2", "y"),
        "bottom left": ("x", "y"),
        "top right": ("x2", "y2"),
        "top left": ("x", "y2"),
        "top center": ("xmid", "y2"),
        "bottom center": ("xmid", "y"),
        "left center": ("x", "ymid"),
        "right center": ("x2", "ymid"),
    }

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws the tooltip overlaid on another component.

        Overrides and extends ToolTip.overlay()

        When **clip_to_plot** is set, drawing is clipped to the component's
        rectangle.  The graphics state saved for that purpose is restored
        even if layout or rendering raises.
        """
        # Latch the flag so that save_state()/restore_state() stay paired
        # even if clip_to_plot is modified while drawing.
        clip = self.clip_to_plot
        if clip:
            gc.save_state()

        try:
            if clip:
                c = component
                gc.clip_to_rect(c.x, c.y, c.width, c.height)

            self.do_layout()

            if self.label_style == 'box':
                self._render_box(component, gc, view_bounds=view_bounds,
                                 mode=mode)
            else:
                self._render_bubble(component,
                                    gc,
                                    view_bounds=view_bounds,
                                    mode=mode)

            # draw the marker
            if self.marker_visible:
                render_markers(gc, [self._screen_coords], self.marker,
                               self.marker_size, self.marker_color_,
                               self.marker_line_width,
                               self.marker_line_color_, self.custom_symbol)
        finally:
            if clip:
                gc.restore_state()

    def _render_box(self, component, gc, view_bounds=None, mode='normal'):
        """ Renders the 'box' style label: an optional line arrow from the
        label to the data point, followed by the ToolTip box itself.
        """
        # draw the arrow if necessary
        if self.arrow_visible:
            if self._cached_arrow is None:
                if self.arrow_root in self._root_positions:
                    # An explicit root position was requested.
                    ox, oy = self._root_positions[self.arrow_root]
                else:
                    if self.arrow_root == "auto":
                        arrow_root = self.label_position
                    else:
                        arrow_root = self.arrow_root
                    # Unknown positions fall back to the numeric center of
                    # the label ("DUMMY" is never a key of _root_positions).
                    pos = self._position_root_map.get(arrow_root, "DUMMY")
                    ox, oy = self._root_positions.get(
                        pos,
                        (self.x + self.width / 2, self.y + self.height / 2))

                # Table entries are attribute names; resolve them to
                # coordinates.  The numeric fallback is used as-is.
                if isinstance(ox, str):
                    ox = getattr(self, ox)
                    oy = getattr(self, oy)
                self._cached_arrow = draw_arrow(gc, (ox, oy),
                                                self._screen_coords,
                                                self.arrow_color_,
                                                arrowhead_size=self.arrow_size,
                                                offset1=3,
                                                offset2=self.marker_size + 3,
                                                minlen=self.arrow_min_length,
                                                maxlen=self.arrow_max_length)
            else:
                # Reuse the arrow computed on a previous draw.
                draw_arrow(gc,
                           None,
                           None,
                           self.arrow_color_,
                           arrow=self._cached_arrow,
                           minlen=self.arrow_min_length,
                           maxlen=self.arrow_max_length)

        # layout and render the label itself
        ToolTip.overlay(self, component, gc, view_bounds, mode)

    def _render_bubble(self, component, gc, view_bounds=None, mode='normal'):
        """ Render the bubble label in the graphics context.

        The outline is a rounded rectangle; when the arrow is visible and at
        least **arrow_min_length** long, the side facing the data point is
        interrupted by a wedge whose tip is the data point.
        """
        # (px, py) is the data point in screen space.
        px, py = self._screen_coords

        # (x, y) is the lower left corner of the label.
        x = self.x
        y = self.y
        # (x2, y2) is the upper right corner of the label.
        x2 = self.x2
        y2 = self.y2
        # r is the corner radius.
        r = self.corner_radius

        # These are only meaningful when the wedge is actually drawn.
        arrow_visible = False
        region = None

        if self.arrow_visible:
            # FIXME: Make 'gap_width' a configurable trait (and give it a
            #        better name).
            max_gap_width = 10
            gap_width = min(max_gap_width, abs(x2 - x - 2 * r),
                            abs(y2 - y - 2 * r))
            # NOTE(review): find_region presumably returns one of 'left',
            # 'right', 'top' or 'bottom' -- confirm against its definition.
            region = find_region(px, py, x, y, x2, y2)

            # Figure out where the "arrow" connects to the "bubble".
            if region == 'left' or region == 'right':
                # Center the gap on py, clamped to the straight part of
                # the vertical edge (between the rounded corners).
                gap_start = py - gap_width / 2
                if gap_start < y + r:
                    gap_start = y + r
                elif gap_start > y2 - r - gap_width:
                    gap_start = y2 - r - gap_width
                by = gap_start + 0.5 * gap_width
                if region == 'left':
                    bx = x
                else:
                    bx = x2
            else:
                # Same as above, along the horizontal edge.
                gap_start = px - gap_width / 2
                if gap_start < x + r:
                    gap_start = x + r
                elif gap_start > x2 - r - gap_width:
                    gap_start = x2 - r - gap_width
                bx = gap_start + 0.5 * gap_width
                if region == 'top':
                    by = y2
                else:
                    by = y
            arrow_len = sqrt((px - bx)**2 + (py - by)**2)
            arrow_visible = arrow_len >= self.arrow_min_length

        with gc:
            if self.border_visible:
                gc.set_line_width(self.border_width)
                gc.set_stroke_color(self.border_color_)
            else:
                gc.set_line_width(0)
                gc.set_stroke_color((0, 0, 0, 0))
            gc.set_fill_color(self.bgcolor_)

            # Start at the lower left, on the left edge where the curved
            # part of the box ends.
            gc.move_to(x, y + r)

            # Draw the left side and the upper left curved corner.
            if arrow_visible and region == 'left':
                gc.line_to(x, gap_start)
                gc.line_to(px, py)
                gc.line_to(x, gap_start + gap_width)
            gc.arc_to(x, y2, x + r, y2, r)

            # Draw the top and the upper right curved corner.
            if arrow_visible and region == 'top':
                gc.line_to(gap_start, y2)
                gc.line_to(px, py)
                gc.line_to(gap_start + gap_width, y2)
            gc.arc_to(x2, y2, x2, y2 - r, r)

            # Draw the right side and the lower right curved corner.
            if arrow_visible and region == 'right':
                gc.line_to(x2, gap_start + gap_width)
                gc.line_to(px, py)
                gc.line_to(x2, gap_start)
            gc.arc_to(x2, y, x2 - r, y, r)

            # Draw the bottom and the lower left curved corner.
            if arrow_visible and region == 'bottom':
                gc.line_to(gap_start + gap_width, y)
                gc.line_to(px, py)
                gc.line_to(gap_start, y)
            gc.arc_to(x, y, x, y + r, r)

            # Finish the "bubble".
            gc.draw_path()

            self._draw_overlay(gc)

    def _do_layout(self, size=None):
        """Computes the size and position of the label and arrow.

        Overrides and extends ToolTip._do_layout()

        Does nothing when there is no component or the component has no
        ``map_screen`` method.
        """
        if not self.component or not hasattr(self.component, "map_screen"):
            return

        # Call the parent class layout first, then position the label
        # relative to the data point's screen coordinates.
        ToolTip._do_layout(self)

        self._screen_coords = self.component.map_screen([self.data_point])[0]
        sx, sy = self._screen_coords

        if isinstance(self.label_position, str):
            # Named position such as "top left", "bottom" or "center".
            orientation = self.label_position
            if ("left" in orientation) or ("right" in orientation):
                if " " not in orientation:
                    # Single-word position: center vertically on the point.
                    self.y = sy - self.height / 2
                if "left" in orientation:
                    self.outer_x = sx - self.outer_width - 1
                elif "right" in orientation:
                    self.outer_x = sx
            if ("top" in orientation) or ("bottom" in orientation):
                if " " not in orientation:
                    # Single-word position: center horizontally on the point.
                    self.x = sx - self.width / 2
                if "bottom" in orientation:
                    self.outer_y = sy - self.outer_height - 1
                elif "top" in orientation:
                    self.outer_y = sy
            if "center" in orientation:
                if " " not in orientation:
                    self.x = sx - (self.width / 2)
                    self.y = sy - (self.height / 2)
                else:
                    self.x = sx - (self.outer_width / 2) - 1
                    self.y = sy - (self.outer_height / 2) - 1
        else:
            # Otherwise label_position is treated as an (x, y) offset from
            # the data point in screen space.
            self.x = sx + self.label_position[0]
            self.y = sy + self.label_position[1]

        # The label may have moved, so the arrow must be recomputed.
        self._cached_arrow = None
        return

    def _data_point_changed(self, old, new):
        """ Rebuilds the label text when a new (non-None) data point is set.
        """
        if new is not None:
            self._create_new_labels()

    def _label_format_changed(self, old, new):
        """ Rebuilds the label text when the format string changes. """
        self._create_new_labels()

    def _label_text_changed(self, old, new):
        """ Rebuilds the label text when the label text changes. """
        self._create_new_labels()

    def _show_label_coords_changed(self, old, new):
        """ Rebuilds the label text when coordinate display is toggled. """
        self._create_new_labels()

    def _create_new_labels(self):
        """ Sets ``self.lines`` from the label text and, if
        **show_label_coords** is set, the formatted data point.  Does
        nothing while **data_point** is None.
        """
        pt = self.data_point
        if pt is not None:
            if self.show_label_coords:
                self.lines = [
                    self.label_text, self.label_format % {
                        "x": pt[0],
                        "y": pt[1]
                    }
                ]
            else:
                self.lines = [self.label_text]

    def _component_changed(self, old, new):
        """ Detaches the mapper listeners from the old component and
        attaches them to the new one.
        """
        for comp, attach in ((old, False), (new, True)):
            if comp is not None:
                if hasattr(comp, 'index_mapper'):
                    self._modify_mapper_listeners(comp.index_mapper,
                                                  attach=attach)
                if hasattr(comp, 'value_mapper'):
                    self._modify_mapper_listeners(comp.value_mapper,
                                                  attach=attach)
        return

    def _modify_mapper_listeners(self, mapper, attach=True):
        """ Adds (attach=True) or removes (attach=False) the handler for
        the 'updated' event of *mapper*; a None mapper is ignored.
        """
        if mapper is not None:
            mapper.on_trait_change(self._handle_mapper,
                                   'updated',
                                   remove=not attach)
        return

    def _handle_mapper(self):
        # This gets fired whenever a mapper on our plot fires its
        # 'updated' event.
        self._layout_needed = True

    @on_trait_change("arrow_size,arrow_root,arrow_min_length," +
                     "arrow_max_length")
    def _invalidate_arrow(self):
        """ Discards the cached arrow and requests a new layout. """
        self._cached_arrow = None
        self._layout_needed = True

    @on_trait_change("label_position,position,position_items,bounds," +
                     "bounds_items")
    def _invalidate_layout(self):
        """ Requests a new layout. """
        self._layout_needed = True

    def _get_xmid(self):
        """ Getter for **xmid**: the average of x and x2. """
        return 0.5 * (self.x + self.x2)

    def _get_ymid(self):
        """ Getter for **ymid**: the average of y and y2. """
        return 0.5 * (self.y + self.y2)
Exemplo n.º 14
0
    array, argsort, concatenate, cos, diff, dot, dtype, empty, float32,
    isfinite, nonzero, pi, searchsorted, seterr, sin, int8
)

# Enthought library imports
from traits.api import Enum, ArrayOrNone

# Integer code associated with each ordering keyword.
delta = dict(ascending=1, descending=-1, flat=0)

# Structured dtype with four float32 fields named r, g, b and a.
rgba_dtype = dtype(list(zip("rgba", [float32] * 4)))
# Structured dtype with two float fields named x and y.
point_dtype = dtype(list(zip("xy", [float] * 2)))

# Dimensions

# A single array of numbers.
NumericalSequenceTrait = ArrayOrNone()

# A sequence of pairs of numbers, i.e., an Nx2 array.
PointTrait = ArrayOrNone(shape=(None, 2))

# An NxM array of numbers or NxMxRGB(A) array of colors.
ImageTrait = ArrayOrNone()

# A 3D array of numbers of shape (Nx, Ny, Nz).
CubeTrait = ArrayOrNone(shape=(None, None, None))


# The fundamental mathematical coordinate types that Chaco supports.
DimensionTrait = Enum("scalar", "point", "image", "cube")
Exemplo n.º 15
0
class PointObject(Object):
    """Represent a group of individual points in a mayavi scene.

    The points are drawn as glyphs (arrows for the 'arrow' view, spheres
    otherwise) and can optionally be labeled, projected onto / oriented
    toward a surface, scaled by their distance from it, or marked when they
    lie inside it.
    """

    # Whether text labels are shown (see _show_labels) and their scale.
    label = Bool(False)
    label_scale = Float(0.01)
    projectable = Bool(False)  # set based on type of points
    # True when ``nearest`` holds more than one data point.
    orientable = Property(depends_on=['nearest'])
    # Text objects currently shown as labels.
    text3d = List
    point_scale = Float(10, label='Point Scale')

    # projection onto a surface
    nearest = Instance(_DistanceQuery)
    check_inside = Instance(_CheckInside)
    project_to_trans = ArrayOrNone(float, shape=(4, 4))
    project_to_surface = Bool(False,
                              label='Project',
                              desc='project points '
                              'onto the surface')
    orient_to_surface = Bool(False,
                             label='Orient',
                             desc='orient points '
                             'toward the surface')
    scale_by_distance = Bool(False,
                             label='Dist.',
                             desc='scale points by '
                             'distance from the surface')
    mark_inside = Bool(False,
                       label='Mark',
                       desc='mark points inside the '
                       'surface in a different color')
    inside_color = RGBColor((0., 0., 0.))

    # The mayavi glyph module created in _plot_points, and its resolution.
    glyph = Instance(Glyph)
    resolution = Int(8)

    view = View(
        HGroup(Item('visible', show_label=False),
               Item('color', show_label=False), Item('opacity')))

    def __init__(self, view='points', has_norm=False, *args, **kwargs):
        """Init.

        Parameters
        ----------
        view : 'points' | 'cloud' | 'arrow'
            Whether the view options should be tailored to individual points,
            a point cloud, or arrows.
        has_norm : bool
            Whether a norm can be defined; adds view options based on point
            norms (default False).
        *args, **kwargs
            Passed on to the parent class constructor.
        """
        assert view in ('points', 'cloud', 'arrow')
        self._view = view
        self._has_norm = bool(has_norm)
        super(PointObject, self).__init__(*args, **kwargs)

    def default_traits_view(self):  # noqa: D102
        # Build the view shown for this object; its contents depend on the
        # view type and on whether norm-based options are enabled.
        color = Item('color', show_label=False)
        scale = Item('point_scale',
                     label='Size',
                     width=_SCALE_WIDTH,
                     editor=laggy_float_editor_headscale)
        orient = Item('orient_to_surface',
                      enabled_when='orientable and not project_to_surface',
                      tooltip='Orient points toward the surface')
        dist = Item('scale_by_distance',
                    enabled_when='orientable and not project_to_surface',
                    tooltip='Scale points by distance from the surface')
        mark = Item('mark_inside',
                    enabled_when='orientable and not project_to_surface',
                    tooltip='Mark points inside the surface using a different '
                    'color')
        if self._view == 'arrow':
            # Arrows get a fixed set of controls and return early.
            visible = Item('visible', label='Show', show_label=False)
            return View(HGroup(visible, scale, 'opacity', 'label', Spring()))
        elif self._view == 'points':
            visible = Item('visible', label='Show', show_label=True)
            views = (visible, color, scale, 'label')
        else:
            assert self._view == 'cloud'
            visible = Item('visible', show_label=False)
            views = (visible, color, scale)

        if not self._has_norm:
            return View(HGroup(*views))

        # Surface-related options are only offered when norms are available.
        group2 = HGroup(dist,
                        Item('project_to_surface',
                             show_label=True,
                             enabled_when='projectable',
                             tooltip='Project points onto the surface '
                             '(for visualization, does not affect '
                             'fitting)'),
                        orient,
                        mark,
                        Spring(),
                        show_left=False)
        return View(HGroup(HGroup(*views), group2))

    @on_trait_change('label')
    def _show_labels(self, show):
        """Remove existing text labels and, if *show* is set, add new ones."""
        # NOTE(review): presumably suspends rendering while the labels are
        # rebuilt -- confirm against _toggle_mlab_render.
        _toggle_mlab_render(self, False)
        while self.text3d:
            text = self.text3d.pop()
            text.remove()

        if show and len(self.src.data.points) > 0:
            fig = self.scene.mayavi_scene
            if self._view == 'arrow':  # for axes
                # A single label with the object's name at the first point.
                x, y, z = self.src.data.points[0]
                self.text3d.append(
                    text3d(x,
                           y,
                           z,
                           self.name,
                           scale=self.label_scale,
                           color=self.color,
                           figure=fig))
            else:
                # One label per point, showing the point's index.
                for i, (x, y, z) in enumerate(np.array(self.src.data.points)):
                    self.text3d.append(
                        text3d(x,
                               y,
                               z,
                               ' %i' % i,
                               scale=self.label_scale,
                               color=self.color,
                               figure=fig))
        _toggle_mlab_render(self, True)

    @on_trait_change('visible')
    def _on_hide(self):
        """Turn the labels off when the object is hidden."""
        if not self.visible:
            self.label = False

    @on_trait_change('scene.activated')
    def _plot_points(self):
        """Add the points to the mayavi pipeline"""
        if self.scene is None:
            return
        # Remove any previously created glyph and source.
        if hasattr(self.glyph, 'remove'):
            self.glyph.remove()
        if hasattr(self.src, 'remove'):
            self.src.remove()

        _toggle_mlab_render(self, False)
        x, y, z = self.points.T
        fig = self.scene.mayavi_scene
        scatter = pipeline.scalar_scatter(x, y, z, fig=fig)
        if not scatter.running:
            # this can occur sometimes during testing w/ui.dispose()
            # NOTE(review): this early return skips the
            # _toggle_mlab_render(self, True) call at the end of the method.
            return
        # fig.scene.engine.current_object is scatter
        mode = 'arrow' if self._view == 'arrow' else 'sphere'
        glyph = pipeline.glyph(scatter,
                               color=self.color,
                               figure=fig,
                               scale_factor=self.point_scale,
                               opacity=1.,
                               resolution=self.resolution,
                               mode=mode)
        glyph.actor.property.backface_culling = True
        glyph.glyph.glyph.vector_mode = 'use_normal'
        glyph.glyph.glyph.clamping = False
        if mode == 'arrow':
            glyph.glyph.glyph_source.glyph_position = 'tail'

        glyph.actor.mapper.color_mode = 'map_scalars'
        glyph.actor.mapper.scalar_mode = 'use_point_data'
        glyph.actor.mapper.use_lookup_table_scalar_range = False

        self.src = scatter
        self.glyph = glyph

        # Keep the glyph's properties in sync with this object's traits.
        self.sync_trait('point_scale', self.glyph.glyph.glyph, 'scale_factor')
        self.sync_trait('color', self.glyph.actor.property, mutual=False)
        self.sync_trait('visible', self.glyph)
        self.sync_trait('opacity', self.glyph.actor.property)
        self.sync_trait('mark_inside', self.glyph.actor.mapper,
                        'scalar_visibility')
        self.on_trait_change(self._update_points, 'points')
        # Apply the current scaling, marker type and colors to the new glyph.
        self._update_marker_scaling()
        self._update_marker_type()
        self._update_colors()
        _toggle_mlab_render(self, True)
        # self.scene.camera.parallel_scale = _scale

    def _nearest_default(self):
        # Default for ``nearest``: a distance query over one point at the
        # origin.
        return _DistanceQuery(np.zeros((1, 3)))

    def _get_nearest(self, proj_rr):
        """Return the nearest surface points and normals for *proj_rr*.

        Both are transformed with ``project_to_trans``; the normals are
        transformed with ``move=False``.
        """
        idx = self.nearest.query(proj_rr)[1]
        proj_pts = apply_trans(self.project_to_trans, self.nearest.data[idx])
        proj_nn = apply_trans(self.project_to_trans,
                              self.check_inside.surf['nn'][idx],
                              move=False)
        return proj_pts, proj_nn

    @on_trait_change('points,project_to_trans,project_to_surface,mark_inside,'
                     'nearest')
    def _update_projections(self):
        """Update the styles of the plotted points."""
        # Nothing to update until the source has been created.
        if not hasattr(self.src, 'data'):
            return
        if self._view == 'arrow':
            # Arrows only need their normals refreshed.
            self.src.data.point_data.normals = self.nn
            self.src.data.point_data.update()
            return
        # projections
        if len(self.nearest.data) <= 1 or len(self.points) == 0:
            return

        # Do the projections
        pts = self.points
        inv_trans = np.linalg.inv(self.project_to_trans)
        proj_rr = apply_trans(inv_trans, self.points)
        proj_pts, proj_nn = self._get_nearest(proj_rr)
        vec = pts - proj_pts  # point to the surface
        if self.project_to_surface:
            pts = proj_pts
        nn = proj_nn
        if self.mark_inside and not self.project_to_surface:
            # Negated result of check_inside, cast to int, used as scalars.
            scalars = (~self.check_inside(proj_rr, verbose=False)).astype(int)
        else:
            scalars = np.ones(len(pts))
        # With this, a point exactly on the surface is of size point_scale
        dist = np.linalg.norm(vec, axis=-1, keepdims=True)
        self.src.data.point_data.normals = (250 * dist + 1) * nn
        self.src.data.point_data.scalars = scalars
        self.glyph.actor.mapper.scalar_range = [0., 1.]
        self.src.data.points = pts  # projection can change this
        self.src.data.point_data.update()

    @on_trait_change('color,inside_color')
    def _update_colors(self):
        """Rebuild the two-entry lookup table used to color the glyphs."""
        if self.glyph is None:
            return
        # inside_color is the surface color, let's try to get far
        # from that
        inside = np.array(self.inside_color)
        # if it's too close to gray, just use black:
        if np.mean(np.abs(inside - 0.5)) < 0.2:
            inside.fill(0.)
        else:
            inside = 1 - inside
        # Two RGBA rows (the derived inside color, then self.color), with
        # alpha 1, scaled to the 0-255 range.
        colors = np.array([tuple(inside) + (1, ),
                           tuple(self.color) + (1, )]) * 255.
        self.glyph.module_manager.scalar_lut_manager.lut.table = colors

    @on_trait_change('project_to_surface,orient_to_surface')
    def _update_marker_type(self):
        """Use cylinders when projecting/orienting, spheres otherwise."""
        # not implemented for arrow
        if self.glyph is None or self._view == 'arrow':
            return
        defaults = DEFAULTS['coreg']
        gs = self.glyph.glyph.glyph_source
        # Carry the current resolution over to the new glyph source.
        res = getattr(gs.glyph_source, 'theta_resolution',
                      getattr(gs.glyph_source, 'resolution', None))
        if self.project_to_surface or self.orient_to_surface:
            gs.glyph_source = tvtk.CylinderSource()
            gs.glyph_source.height = defaults['eegp_height']
            gs.glyph_source.center = (0., -defaults['eegp_height'], 0)
            gs.glyph_source.resolution = res
        else:
            gs.glyph_source = tvtk.SphereSource()
            gs.glyph_source.phi_resolution = res
            gs.glyph_source.theta_resolution = res

    @on_trait_change('scale_by_distance,project_to_surface')
    def _update_marker_scaling(self):
        """Select the glyph scale mode from the current options."""
        if self.glyph is None:
            return
        if self.scale_by_distance and not self.project_to_surface:
            self.glyph.glyph.scale_mode = 'scale_by_vector'
        else:
            self.glyph.glyph.scale_mode = 'data_scaling_off'

    def _resolution_changed(self, new):
        """Apply a new resolution to the current glyph source."""
        if not self.glyph:
            return
        gs = self.glyph.glyph.glyph_source.glyph_source
        if isinstance(gs, tvtk.SphereSource):
            gs.phi_resolution = new
            gs.theta_resolution = new
        elif isinstance(gs, tvtk.CylinderSource):
            gs.resolution = new
        else:  # ArrowSource
            gs.tip_resolution = new
            gs.shaft_resolution = new

    @cached_property
    def _get_orientable(self):
        # Getter for ``orientable``: True when ``nearest`` holds more than
        # one data point.
        return len(self.nearest.data) > 1
Exemplo n.º 16
0
class Legend(AbstractOverlay):
    """ A legend for a plot.

    The legend shows one entry (icon plus text label) per item of **plots**,
    optionally restricted and ordered by **labels**, together with a title
    entry.  It is drawn as an overlay when it has a **component**, and
    through the regular plot-drawing pass when it has none.
    """
    # The font to use for the legend text.
    font = KivaFont("modern 12")

    # The amount of space between the content of the legend and the border.
    border_padding = Int(10)

    # The border is visible (overrides Enable Component).
    border_visible = True

    # The color of the text labels
    color = black_color_trait

    # The background color of the legend (overrides AbstractOverlay).
    bgcolor = white_color_trait

    # The position of the legend with respect to its overlaid component.  (This
    # attribute applies only if the legend is used as an overlay.)  The first
    # character is the vertical alignment, the second the horizontal one.
    #
    # * ur = Upper Right
    # * ul = Upper Left
    # * ll = Lower Left
    # * lr = Lower Right
    align = Enum("ur", "ul", "ll", "lr")

    # The amount of space between legend items.
    line_spacing = Int(3)

    # The size of the icon or marker area drawn next to the label, as
    # [width, height].
    icon_bounds = List([24, 24])

    # Amount of spacing between each label and its icon.
    icon_spacing = Int(5)

    # Map of labels (strings) to plot instances or lists of plot instances.  The
    # Legend determines the appropriate rendering of each plot's marker/line.
    plots = Dict

    # The list of labels to show and the order to show them in.  If this
    # list is blank, then the keys of self.plots are used and displayed in
    # alphabetical order.  Otherwise, only the items in the **labels**
    # list are drawn in the legend.  Labels are ordered from top to bottom.
    labels = List

    # Whether or not to hide plots that are not visible.  (This is checked during
    # layout.)  This option *will* filter out the items in **labels** above, so
    # if you absolutely, positively want to set the items that will always
    # display in the legend, regardless of anything else, then you should turn
    # this option off.  Otherwise, it usually makes sense that a plot renderer
    # that is not visible will also not be in the legend.
    hide_invisible_plots = Bool(True)

    # If hide_invisible_plots is False, we can still choose to render the
    # entries of invisible plots with this alpha.
    invisible_plot_alpha = Float(0.33)

    # The renderer that draws the icon for an entry that maps to more than
    # one plot.
    composite_icon_renderer = Instance(AbstractCompositeIconRenderer)

    # Action that the legend takes when it encounters a plot whose icon it
    # cannot render:
    #
    # * 'skip': skip it altogether and don't render its name
    # * 'blank': render the name but leave the icon blank (color=self.bgcolor)
    # * 'questionmark': render a "question mark" icon
    #   (NOTE(review): _render_error() currently draws the same blank
    #   rectangle for 'questionmark' as for 'blank'.)
    error_icon = Enum("skip", "blank", "questionmark")

    # Should the legend clip to the bounds it needs, or to its parent?
    clip_to_component = Bool(False)

    # The resizable directions of the legend (overrides PlotComponent).
    # get_preferred_size() reports the size computed from the labels for
    # every direction listed here, and the current outer size otherwise.
    resizable = "hv"

    # An optional title string to show on the legend.
    title = Str('')

    # If True, title is at top, if False then at bottom.
    title_at_top = Bool(True)

    # The legend draws itself in one pass when its parent is drawing
    # the **draw_layer** (overrides PlotComponent).
    unified_draw = True
    # The legend is drawn on the overlay layer of its parent (overrides
    # PlotComponent).
    draw_layer = "overlay"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # A cached list of Label instances
    _cached_labels = List

    # A cached array of label sizes (one (width, height) row per label), or
    # None when the caches have been invalidated.
    _cached_label_sizes = ArrayOrNone()

    # A cached list of label names.
    _cached_label_names = CList

    # A list of the visible plots.  Each plot corresponds to the label at
    # the same index in _cached_label_names.  This list does not necessarily
    # correspond to self.plots.values() because it is sorted according to
    # the plot name and it potentially excludes invisible plots.
    _cached_visible_plots = CList

    # A cached array of label positions (one (x, y) row per label), or None
    # when the caches have been invalidated.  The values are filled in by
    # _draw_as_overlay() starting from self.x and self.y2.
    _cached_label_positions = ArrayOrNone()

    def is_in(self, x, y):
        """ Returns whether the point (x, y) falls inside the legend.

        Overrides the parent class implementation because legend alignment
        and padding do not cooperate with it: the Component version of
        is_in uses the outer positions, which include the padding, while
        this one tests against the legend's own position and bounds.  (This
        may just be caused by a questionable implementation of the legend
        tool, which works by adjusting the padding.)
        """
        within_x = self.x <= x <= self.x + self.width
        within_y = self.y <= y <= self.y + self.height
        return within_x and within_y

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        Implements AbstractOverlay.

        The legend is laid out, its outer position is set to the corner of
        **component** selected by **align**, and it is then drawn, clipped
        to self.component when **clip_to_component** is set.
        """
        self.do_layout()

        # **align** is a two-character code: vertical, then horizontal.
        vertical, horizontal = self.align
        y = (component.y2 - self.outer_height if vertical == "u"
             else component.y)
        x = (component.x2 - self.outer_width if horizontal == "r"
             else component.x)
        self.outer_position = [x, y]

        if not self.clip_to_component:
            PlotComponent._draw(self, gc, view_bounds, mode)
            return

        clip_target = self.component
        with gc:
            gc.clip_to_rect(clip_target.x, clip_target.y,
                            clip_target.width, clip_target.height)
            PlotComponent._draw(self, gc, view_bounds, mode)

    # The following two methods implement the functionality of the Legend
    # to act as a first-class component instead of merely as an overlay.
    # They make the Legend use the normal PlotComponent render methods when
    # its .component attribute is None, so that it can have its own
    # overlays (e.g. a PlotLabel).
    #
    # The core legend rendering method is named _draw_as_overlay() so that
    # it can be called from _draw_plot() when the Legend is not an overlay,
    # and from _draw_overlay() when the Legend is an overlay.

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        if self.component is None:
            self._draw_as_overlay(gc, view_bounds, mode)
        return

    def _draw_overlay(self, gc, view_bounds=None, mode="normal"):
        if self.component is not None:
            self._draw_as_overlay(gc, view_bounds, mode)
        else:
            PlotComponent._draw_overlay(self, gc, view_bounds, mode)
        return

    def _draw_as_overlay(self, gc, view_bounds=None, mode="normal"):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.
        """
        # Determine the position we are going to draw at from our alignment
        # corner and the corresponding outer_padding parameters.  (Position
        # refers to the lower-left corner of our border.)

        # First draw the border, if necesssary.  This sort of duplicates
        # the code in PlotComponent._draw_overlay, which is unfortunate;
        # on the other hand, overlays of overlays seem like a rather obscure
        # feature.

        with gc:
            gc.clip_to_rect(int(self.x), int(self.y), int(self.width),
                            int(self.height))
            edge_space = self.border_width + self.border_padding
            icon_width, icon_height = self.icon_bounds

            icon_x = self.x + edge_space
            text_x = icon_x + icon_width + self.icon_spacing
            y = self.y2 - edge_space

            if self._cached_label_positions is not None:
                if len(self._cached_label_positions) > 0:
                    self._cached_label_positions[:, 0] = icon_x

            for i, label_name in enumerate(self._cached_label_names):
                # Compute the current label's position
                label_height = self._cached_label_sizes[i][1]
                y -= label_height
                self._cached_label_positions[i][1] = y

                # Try to render the icon
                icon_y = y + (label_height - icon_height) / 2
                #plots = self.plots[label_name]
                plots = self._cached_visible_plots[i]
                render_args = (gc, icon_x, icon_y, icon_width, icon_height)

                try:
                    if isinstance(plots, list) or isinstance(plots, tuple):
                        # TODO: How do we determine if a *group* of plots is
                        # visible or not?  For now, just look at the first one
                        # and assume that applies to all of them
                        if not plots[0].visible:
                            # TODO: the get_alpha() method isn't supported on the Mac kiva backend
                            #old_alpha = gc.get_alpha()
                            old_alpha = 1.0
                            gc.set_alpha(self.invisible_plot_alpha)
                        else:
                            old_alpha = None
                        if len(plots) == 1:
                            plots[0]._render_icon(*render_args)
                        else:
                            self.composite_icon_renderer.render_icon(
                                plots, *render_args)
                    elif plots is not None:
                        # Single plot
                        if not plots.visible:
                            #old_alpha = gc.get_alpha()
                            old_alpha = 1.0
                            gc.set_alpha(self.invisible_plot_alpha)
                        else:
                            old_alpha = None
                        plots._render_icon(*render_args)
                    else:
                        old_alpha = None  # Or maybe 1.0?

                    icon_drawn = True
                except:
                    icon_drawn = self._render_error(*render_args)

                if icon_drawn:
                    # Render the text
                    gc.translate_ctm(text_x, y)
                    gc.set_antialias(0)
                    self._cached_labels[i].draw(gc)
                    gc.set_antialias(1)
                    gc.translate_ctm(-text_x, -y)

                    # Advance y to the next label's baseline
                    y -= self.line_spacing
                if old_alpha is not None:
                    gc.set_alpha(old_alpha)

        return

    def _render_error(self, gc, icon_x, icon_y, icon_width, icon_height):
        """ Renders an error icon or performs some other action when a
        plot is unable to render its icon.

        Returns True if something was actually drawn (and hence the legend
        needs to advance the line) or False if nothing was drawn.
        """
        if self.error_icon == "skip":
            return False
        elif self.error_icon == "blank" or self.error_icon == "questionmark":
            with gc:
                gc.set_fill_color(self.bgcolor_)
                gc.rect(icon_x, icon_y, icon_width, icon_height)
                gc.fill_path()
            return True
        else:
            return False

    def get_preferred_size(self):
        """
        Computes the preferred size of the legend from the sizes of its
        labels, and rebuilds the label caches used for drawing.

        The entries are taken from **labels** when it is non-empty, and from
        the sorted keys of **plots** otherwise.  With **hide_invisible_plots**
        set, names that are not in **plots** and plots that are not visible
        are dropped.  With it unset, every name is kept and paired with its
        plot; a name missing from **plots** is paired with None, which is
        drawn as a text-only entry.  A title entry is always added, at the
        top or the bottom depending on **title_at_top**.

        Returns
        -------
        [width, height] : list
            [0, 0] when there are no plots.  For a direction that is not
            listed in **resizable**, the current outer size is returned
            instead of the computed one.
        """
        if len(self.plots) == 0:
            return [0, 0]

        # Work on a copy: the title entry is inserted into this list below,
        # and that must never leak into the user-facing **labels** trait.
        label_names = list(self.labels)
        if len(label_names) == 0:
            label_names = sorted(self.plots)

        if self.hide_invisible_plots:
            visible_labels = []
            visible_plots = []
            for name in label_names:
                # If the user set self.labels, there might be a bad value,
                # so ensure that each name is actually in the plots dict.
                if name not in self.plots:
                    continue
                val = self.plots[name]
                # Rather than checking for a list/TraitListObject/etc., we
                # just check for the attribute first
                if hasattr(val, 'visible'):
                    shown = val.visible
                else:
                    # If we have a list of renderers, show the name if any
                    # of them are visible
                    shown = any(renderer.visible for renderer in val)
                if shown:
                    visible_labels.append(name)
                    visible_plots.append(val)
            label_names = visible_labels
        else:
            # Keep the plots aligned with the names actually displayed, so
            # that each label gets the icon of its own plot.
            visible_plots = [self.plots.get(name) for name in label_names]

        # Create the labels
        labels = [self._create_label(text) for text in label_names]

        # For the legend title
        if self.title_at_top:
            labels.insert(0, self._create_label(self.title))
            label_names.insert(0, 'Legend Label')
            visible_plots.insert(0, None)
        else:
            labels.append(self._create_label(self.title))
            label_names.append(self.title)
            visible_plots.append(None)

        # We need a dummy GC in order to get font metrics
        dummy_gc = font_metrics_provider()
        label_sizes = array(
            [label.get_width_height(dummy_gc) for label in labels])

        if len(label_sizes) > 0:
            max_label_width = max(label_sizes[:, 0])
            total_label_height = sum(
                label_sizes[:, 1]) + (len(label_sizes) - 1) * self.line_spacing
        else:
            max_label_width = 0
            total_label_height = 0

        legend_width = (max_label_width + self.icon_spacing
                        + self.icon_bounds[0] + self.hpadding
                        + 2 * self.border_padding)
        legend_height = (total_label_height + self.vpadding
                         + 2 * self.border_padding)

        self._cached_labels = labels
        self._cached_label_sizes = label_sizes
        self._cached_label_positions = zeros_like(label_sizes)
        self._cached_label_names = label_names
        self._cached_visible_plots = visible_plots

        if "h" not in self.resizable:
            legend_width = self.outer_width
        if "v" not in self.resizable:
            legend_height = self.outer_height
        return [legend_width, legend_height]

    def get_label_at(self, x, y):
        """ Returns the label object at (x, y), or None.

        A label is hit when (x, y) lies within the rectangle that starts at
        its cached position and extends by its cached size (edges included).
        The first matching label is returned.  None is returned when no
        label matches, or when the position/size caches are not available
        (for example right after the plots were replaced).
        """
        positions = self._cached_label_positions
        sizes = self._cached_label_sizes
        if positions is None or sizes is None:
            return None

        for pos, size, label in zip(positions, sizes, self._cached_labels):
            corner = pos + size
            if (pos[0] <= x <= corner[0]) and (pos[1] <= y <= corner[1]):
                return label
        return None

    def _do_layout(self):
        if self.component is not None or len(self._cached_labels) == 0 or \
                self._cached_label_sizes is None or len(self._cached_label_names) == 0:
            width, height = self.get_preferred_size()
            self.outer_bounds = [width, height]
        return

    def _create_label(self, text):
        """ Builds the Label instance used to display **text** in the legend,
        using the legend's font and text color on a transparent background
        with no margin or border.

        Subclasses can override this method to customize the creation of
        labels.
        """
        label_traits = dict(
            text=text,
            font=self.font,
            margin=0,
            color=self.color_,
            bgcolor="transparent",
            border_width=0,
        )
        return Label(**label_traits)

    def _composite_icon_renderer_default(self):
        """ Supplies the default value of **composite_icon_renderer**. """
        renderer = CompositeIconRenderer()
        return renderer

    #-- trait handlers --------------------------------------------------------
    def _anytrait_changed(self, name, old, new):
        if name in ("font", "border_padding", "padding", "line_spacing",
                    "icon_bounds", "icon_spacing", "labels", "plots",
                    "plots_items", "labels_items", "border_width", "align",
                    "position", "position_items", "bounds", "bounds_items",
                    "label_at_top"):
            self._layout_needed = True
        if name == "color":
            self.get_preferred_size()
        return

    def _plots_changed(self):
        """ Invalidate the caches.
        """
        self._cached_labels = []
        self._cached_label_sizes = None
        self._cached_label_names = []
        self._cached_visible_plots = []
        self._cached_label_positions = None

    def _title_at_top_changed(self, old, new):
        """ Trait handler for when self.title_at_top changes. """
        if old == True:
            indx = 0
        else:
            indx = -1
        if old != None:
            self._cached_labels.pop(indx)
            self._cached_label_names.pop(indx)
            self._cached_visible_plots.pop(indx)

        # For the legend title
        if self.title_at_top:
            self._cached_labels.insert(0, self._create_label(self.title))
            self._cached_label_names.insert(0, '__legend_label__')
            self._cached_visible_plots.insert(0, None)
        else:
            self._cached_labels.append(self._create_label(self.title))
            self._cached_label_names.append(self.title)
            self._cached_visible_plots.append(None)