class Stroke(HasStrictTraits):
    """A stroke described by an outline path and a sequence of median points."""

    #: Path vertices; shape => (# of points) x 2
    vertices = ArrayOrNone(shape=(None, 2), dtype=np.float64)

    #: Codes for path vertices
    codes = List(String)

    #: Median vertices; shape => # of points x 2
    medians = ArrayOrNone(shape=(None, 2), dtype=np.float64)

    @classmethod
    def from_hanzi_data(cls, path_data, median_data, median_offset=0):
        """Build a Stroke from SVG path data and a sequence of median points.

        ``path_data`` is parsed by ``svg_parse`` into vertices and codes.
        The first median point is replaced by
        ``adjust_point_by(median_data[:2], 0, -median_offset)``.  Note that
        this assignment writes into ``median_data`` itself, so the caller's
        object is modified.
        """
        parsed_vertices, parsed_codes = svg_parse(path_data)
        # NOTE(review): the first two median points are handed to
        # adjust_point_by, presumably so point 0 can be shifted relative to
        # point 1 -- confirm against adjust_point_by.
        median_data[0] = adjust_point_by(median_data[:2], 0, -median_offset)
        return cls(vertices=parsed_vertices, codes=parsed_codes,
                   medians=median_data)

    def to_mpl_path(self):
        """Return this stroke as a matplotlib ``Path``.

        Each stored command code expands into one (M, L, Z), two (Q) or
        three (C) matplotlib path codes.  Any other code raises KeyError.
        """
        expansion = {
            'M': [Path.MOVETO],
            'L': [Path.LINETO],
            'Q': [Path.CURVE3] * 2,
            'C': [Path.CURVE4] * 3,
            'Z': [Path.CLOSEPOLY],
        }
        mpl_codes = [mpl_code
                     for svg_code in self.codes
                     for mpl_code in expansion[svg_code]]
        return Path(self.vertices, mpl_codes)
class Foo(HasTraits):
    """HasTraits class declaring ArrayOrNone traits with a range of
    constructor options."""

    # ArrayOrNone given as the bare trait class, with no arguments.
    maybe_array = ArrayOrNone
    # dtype restricted to float.
    maybe_float_array = ArrayOrNone(dtype=float)
    # Two-dimensional shape; None presumably leaves that dimension's length
    # unconstrained.
    maybe_two_d_array = ArrayOrNone(shape=(None, None))
    # Non-None default value supplied as a Python list.
    maybe_array_with_default = ArrayOrNone(value=[1, 2, 3])
    # comparison_mode=NO_COMPARE: presumably every assignment is reported as
    # a change without comparing to the old value -- confirm in Traits docs.
    maybe_array_no_compare = ArrayOrNone(comparison_mode=NO_COMPARE)
class FiducialsSource(HasTraits):
    """Expose points of a given fiducials fif file.

    Parameters
    ----------
    file : File
        Path to a fif file with fiducials (*.fif).

    Attributes
    ----------
    points : Array, shape = (n_points, 3)
        Fiducials file points.
    """

    file = File(filter=[fid_wildcard])
    fname = Property(depends_on='file')
    points = Property(ArrayOrNone, depends_on='file')
    mni_points = ArrayOrNone(float, shape=(3, 3))

    def _get_fname(self):
        # Just the file name, without its directory.
        return op.basename(self.file)

    @cached_property
    def _get_points(self):
        """Read the fiducial coordinates from ``file``.

        Falls back to ``mni_points`` (which may be None) when ``file`` does
        not exist.  If reading fails, an error dialog is shown, ``file`` is
        reset, and the original exception is re-raised.
        """
        if op.exists(self.file):
            try:
                return _fiducial_coords(*read_fiducials(self.file))
            except Exception as err:
                message = (f"Error reading fiducials from {self.fname}: "
                           f"{err!s} (See terminal for more information)")
                error(None, message, "Error Reading Fiducials")
                self.reset_traits(['file'])
                raise
        return self.mni_points  # can be None
class Gradient(HasStrictTraits):
    """ A color gradient.
    """

    #: The sequence of color stops for the gradient.
    stops = List(Instance(ColorStop))

    #: A trait which fires when the gradient is updated.
    updated = Event()

    #: A temporary cache for the stop array.
    _array_cache = ArrayOrNone()

    def to_array(self):
        """ Return the color stops as a single array sorted by offset.

        This is the raw form of the stops required by Kiva.

        Returns
        -------
        stop_array : array
            An array with one row per color stop, in increasing order of
            ``offset``; each row is that stop's ``to_array()`` value, i.e.
            (offset, r, g, b, a).  The array is cached and returned again on
            later calls until the stops change, so it should not be mutated.
        """
        if self._array_cache is None:
            self._array_cache = array([
                stop.to_array()
                for stop in sorted(self.stops, key=attrgetter("offset"))
            ])
        return self._array_cache

    @observe("stops.items.updated")
    def observe_stops(self, event):
        # Any change to the stops invalidates the cached array; let
        # listeners know the gradient changed.
        self._array_cache = None
        self.updated = True

    def _stops_default(self):
        # Default gradient: white at offset 0.0 to black at offset 1.0.
        return [
            ColorStop(offset=0.0, color="white"),
            ColorStop(offset=1.0, color="black"),
        ]
class RectangularSelection(LassoSelection):
    """ A lasso selection tool whose selection shape is rectangular
    """

    #: The first click. This represents a corner of the rectangle.
    first_corner = ArrayOrNone(shape=(2, ))

    def selecting_mouse_move(self, event):
        """ Update the active selection while the mouse is dragged.

        Behaves like the LassoSelection version, except that the active
        selection is always the rectangle spanned by the first corner and
        the current mouse position (see ``_make_rectangle``).
        """
        # Express the event location relative to this container.
        event.push_transform(self.component.get_event_transform(event),
                             caller=self)
        current = self._map_data(np.array((event.x, event.y)))

        if self.first_corner is None:
            self.first_corner = current

        self._active_selection = self._make_rectangle(self.first_corner,
                                                      current)
        self.updated = True
        if self.incremental_select:
            self._update_selection()
        # Report None for the previous selections
        self.trait_property_changed("disjoint_selections", None)

    def selecting_mouse_up(self, event):
        """ Finish the selection and forget the first corner. """
        super(RectangularSelection, self).selecting_mouse_up(event)
        self.first_corner = None

    def _make_rectangle(self, p1, p2):
        """ Return the closed path around the rectangle with corners p1, p2.

        The four vertices are visited in this order::

            *-----p2
            |      |
            p1-----*
        """
        x1, y1 = p1[0], p1[1]
        x2, y2 = p2[0], p2[1]
        return np.array([p1, [x1, y2], p2, [x2, y1]])
class Bar(HasTraits):
    """HasTraits class whose single trait has a default that does not match
    its declared shape."""

    # The declared shape is 2-D but the default [1, 2, 3] is 1-D.  The name
    # suggests this mismatch is intentional (e.g. a negative test) -- confirm
    # against the code that uses this class.
    bad_array = ArrayOrNone(shape=(None, None), value=[1, 2, 3])
class Bar(HasTraits):
    """HasTraits class contrasting float32 ArrayOrNone traits with and
    without an explicit casting rule."""

    # float32 dtype with no casting argument, so the trait's default casting
    # rule applies.
    unsafe_f32 = ArrayOrNone(dtype="float32")
    # float32 dtype with casting="safe"; presumably follows NumPy's "safe"
    # casting rule (only value-preserving conversions) -- confirm in Traits docs.
    safe_f32 = ArrayOrNone(dtype="float32", casting="safe")
class FooBar(HasTraits):
    """HasTraits class with two ArrayOrNone traits built from the same
    default object."""

    # Both traits receive the very same `test_default` object as their
    # default; presumably used to check whether defaults are copied rather
    # than shared -- verify against the caller.
    foo = ArrayOrNone(value=test_default)
    bar = ArrayOrNone(value=test_default)
class LabelAxis(PlotAxis):
    """ An axis whose ticks are labeled with text instead of numbers.
    """

    #: List of labels to use on tick marks.
    labels = List(Str)

    #: The angle of rotation of the label. Only multiples of 90 are supported.
    label_rotation = Float(0)

    #: List of indices of ticks
    positions = ArrayOrNone()

    def _compute_tick_positions(self, gc, component=None):
        """ Calculates the positions for the tick marks.

        Only entries of ``positions`` that fall inside the mapper's data
        range are candidates; the tick generator's ticks pick which of them
        are shown, and ``labels`` supplies the matching text.

        Overrides PlotAxis.

        Raises
        ------
        RuntimeError
            If the data range's low value is greater than its high value.
        """
        if self.mapper is None:
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = self.mapper.range.low
        datahigh = self.mapper.range.high
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if (datalow == datahigh) or (screenlow == screenhigh) or \
                (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        # Validate the range before doing any work, as PlotAxis does.  The
        # range filter below selects nothing when low > high, so checking
        # any later would never be reached.
        if datalow > datahigh:
            raise RuntimeError(
                "DataRange low is greater than high; unable to compute axis ticks."
            )

        if not self.tick_generator:
            return

        # Get a set of ticks from the tick generator.
        tick_list = array(self.tick_generator.get_ticks(datalow, datahigh,
                                                        datalow, datahigh,
                                                        self.tick_interval),
                          float64)

        # Find all the positions in the current range.  `positions` defaults
        # to None, which is treated as "no positions".
        positions = self.positions if self.positions is not None else []
        pos_index = []
        pos = []
        for i, position in enumerate(positions):
            if datalow <= position <= datahigh:
                pos_index.append(i)
                pos.append(position)

        if not pos_index:
            # No positions currently visible.
            self._tick_positions = []
            self._tick_label_positions = []
            self._tick_label_list = []
            return

        # Use the ticks generated by the tick generator as a guide for
        # selecting the positions to be displayed.
        tick_indices = unique(searchsorted(pos, tick_list))
        tick_indices = tick_indices[tick_indices < len(pos)]
        tick_positions = take(pos, tick_indices)
        self._tick_label_list = take(self.labels,
                                     take(pos_index, tick_indices))

        mapped_label_positions = [
            (self.mapper.map_screen(tick_pos) - screenlow)
            / (screenhigh - screenlow)
            for tick_pos in tick_positions
        ]
        self._tick_positions = [self._axis_vector * tickpos + self._origin_point
                                for tickpos in mapped_label_positions]
        self._tick_label_positions = self._tick_positions
        return

    def _compute_labels(self, gc):
        """Generates the labels for tick marks.

        Overrides PlotAxis.  Failures are reported with a traceback on
        stderr instead of interrupting drawing.
        """
        try:
            self.ticklabel_cache = []
            for text in self._tick_label_list:
                ticklabel = Label(text=text, font=self.tick_label_font,
                                  color=self.tick_label_color,
                                  rotate_angle=self.label_rotation)
                self.ticklabel_cache.append(ticklabel)

            self._tick_label_bounding_boxes = [
                array(ticklabel.get_bounding_box(gc), float64)
                for ticklabel in self.ticklabel_cache
            ]
        except Exception:
            # Best effort: report the problem but keep drawing.  Unlike the
            # former bare ``except:``, this does not swallow
            # KeyboardInterrupt or SystemExit.
            print_exc()
        return
class MarkerPointDest(MarkerPoints):  # noqa: D401
    """MarkerPoints subclass that serves for derived points."""

    src1 = Instance(MarkerPointSource)
    src2 = Instance(MarkerPointSource)

    name = Property(Str, depends_on='src1.name,src2.name')
    dir = Property(Str, depends_on='src1.dir,src2.dir')

    points = Property(ArrayOrNone(float, (5, 3)),
                      depends_on=['method', 'src1.points', 'src1.use',
                                  'src2.points', 'src2.use'])
    enabled = Property(Bool, depends_on=['points'])

    method = Enum('Transform', 'Average', desc="Transform: estimate a rotation"
                  "/translation from mrk1 to mrk2; Average: use the average "
                  "of the mrk1 and mrk2 coordinates for each point.")

    view = View(VGroup(Item('method', style='custom'),
                       Item('save_as', enabled_when='can_save',
                            show_label=False)))

    @cached_property
    def _get_dir(self):
        # The derived points live alongside the first source.
        return self.src1.dir

    @cached_property
    def _get_name(self):
        """Derive a name from the two source names.

        If only one source has a name, that name is used.  If both names
        are equal, that name is used.  Otherwise the longest common prefix
        of the two names is returned (possibly '').
        """
        n1 = self.src1.name
        n2 = self.src2.name

        if not n1:
            if n2:
                return n2
            else:
                return ''
        elif not n2:
            return n1

        if n1 == n2:
            return n1

        # Longest common prefix.  The former loop compared against
        # len(n1) - 2 instead of len(n2) - 1, returning the wrong name
        # (e.g. 'abd' for 'abc'/'abd') or raising IndexError when n2 was
        # a short prefix of n1.
        common = 0
        for c1, c2 in zip(n1, n2):
            if c1 != c2:
                break
            common += 1
        return n1[:common]

    @cached_property
    def _get_enabled(self):
        return np.any(self.points)

    @cached_property
    def _get_points(self):
        """Combine the two sources' points with the selected ``method``.

        When only one source is enabled its points are used directly; when
        neither is, a (5, 3) array of zeros is returned.  On insufficient
        shared points an error dialog is shown and zeros are returned.
        """
        # in case only one or no source is enabled
        if not (self.src1 and self.src1.enabled):
            if (self.src2 and self.src2.enabled):
                return self.src2.points
            else:
                return np.zeros((5, 3))
        elif not (self.src2 and self.src2.enabled):
            return self.src1.points

        # Average method
        if self.method == 'Average':
            if len(np.union1d(self.src1.use, self.src2.use)) < 5:
                error(None, "Need at least one source for each point.",
                      "Marker Average Error")
                return np.zeros((5, 3))

            pts = (self.src1.points + self.src2.points) / 2.
            # Points used by only one source take that source's value.
            for i in np.setdiff1d(self.src1.use, self.src2.use):
                pts[i] = self.src1.points[i]
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = self.src2.points[i]

            return pts

        # Transform method
        idx = np.intersect1d(np.array(self.src1.use),
                             np.array(self.src2.use), assume_unique=True)
        if len(idx) < 3:
            error(None, "Need at least three shared points for trans"
                  "formation.", "Marker Interpolation Error")
            return np.zeros((5, 3))

        src_pts = self.src1.points[idx]
        tgt_pts = self.src2.points[idx]
        est = fit_matched_points(src_pts, tgt_pts, out='params')
        # Half of the estimated rotation/translation, applied forwards from
        # src1 or backwards from src2.
        rot = np.array(est[:3]) / 2.
        tra = np.array(est[3:]) / 2.

        if len(self.src1.use) == 5:
            trans = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans, self.src1.points)
        elif len(self.src2.use) == 5:
            trans = np.dot(translation(*-tra), rotation(*-rot))
            pts = apply_trans(trans, self.src2.points)
        else:
            trans1 = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans1, self.src1.points)
            trans2 = np.dot(translation(*-tra), rotation(*-rot))
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = apply_trans(trans2, self.src2.points[i])

        return pts
class PlotAxis(AbstractOverlay):
    """
    The PlotAxis is a visual component that can be rendered on its own as
    a standalone component or attached as an overlay to another component.
    (To attach it as an overlay, set its **component** attribute.)

    When it is attached as an overlay, it draws into the padding around
    the component.
    """

    #: The mapper that drives this axis.
    mapper = Instance(AbstractMapper)

    #: Keep an origin for plots that aren't attached to a component
    origin = Enum("bottom left", "top left", "bottom right", "top right")

    #: The text of the axis title.
    title = Trait('', Str, Unicode)  # May want to add PlotLabel option

    #: The font of the title.
    title_font = KivaFont('modern 12')

    #: The spacing between the axis line and the title
    title_spacing = Trait('auto', 'auto', Float)

    #: The color of the title.
    title_color = ColorTrait("black")

    #: The angle of the title, in degrees, from horizontal line
    title_angle = Float(0.)

    #: The thickness (in pixels) of each tick.
    tick_weight = Float(1.0)

    #: The color of the ticks.
    tick_color = ColorTrait("black")

    #: The font of the tick labels.
    tick_label_font = KivaFont('modern 10')

    #: The color of the tick labels.
    tick_label_color = ColorTrait("black")

    #: The rotation of the tick labels.
    tick_label_rotate_angle = Float(0)

    #: Whether to align to corners or edges (corner is better for 45 degree rotation)
    tick_label_alignment = Enum('edge', 'corner')

    #: The margin around the tick labels.
    tick_label_margin = Int(2)

    #: The distance of the tick label from the axis.
    tick_label_offset = Float(8.)

    #: Whether the tick labels appear to the inside or the outside of the plot area
    tick_label_position = Enum("outside", "inside")

    #: A callable that is passed the numerical value of each tick label and
    #: that returns a string.
    tick_label_formatter = Callable(DEFAULT_TICK_FORMATTER)

    #: The number of pixels by which the ticks extend into the plot area.
    tick_in = Int(5)

    #: The number of pixels by which the ticks extend into the label area.
    tick_out = Int(5)

    #: Are ticks visible at all?
    tick_visible = Bool(True)

    #: The dataspace interval between ticks.
    tick_interval = Trait('auto', 'auto', Float)

    #: A callable that implements the AbstractTickGenerator interface.
    tick_generator = Instance(AbstractTickGenerator)

    #: The location of the axis relative to the plot.  This determines where
    #: the axis title is located relative to the axis line.
    orientation = Enum("top", "bottom", "left", "right")

    #: Is the axis line visible?
    axis_line_visible = Bool(True)

    #: The color of the axis line.
    axis_line_color = ColorTrait("black")

    #: The line thickness (in pixels) of the axis line.
    axis_line_weight = Float(1.0)

    #: The dash style of the axis line.
    axis_line_style = LineStyle('solid')

    #: A special version of the axis line that is more useful for geophysical
    #: plots.
    small_haxis_style = Bool(False)

    #: Does the axis ensure that its end labels fall within its bounding area?
    ensure_labels_bounded = Bool(False)

    #: Does the axis prevent the ticks from being rendered outside its bounds?
    #: This flag is off by default because the standard axis *does* render ticks
    #: that encroach on the plot area.
    ensure_ticks_bounded = Bool(False)

    #: Fired when the axis's range bounds change.
    updated = Event

    #------------------------------------------------------------------------
    # Override default values of inherited traits
    #------------------------------------------------------------------------

    #: Background color (overrides AbstractOverlay). Axes usually let the color of
    #: the container show through.
    bgcolor = ColorTrait("transparent")

    #: Dimensions that the axis is resizable in (overrides PlotComponent).
    #: Typically, axes are resizable in both dimensions.
    resizable = "hv"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # Cached position calculations

    _tick_list = List
    # These are caches of their respective positions
    _tick_positions = ArrayOrNone()
    _tick_label_list = ArrayOrNone()
    _tick_label_positions = ArrayOrNone()
    _tick_label_bounding_boxes = List
    _major_axis_size = Float
    _minor_axis_size = Float
    _major_axis = Array
    _title_orientation = Array
    _title_angle = Float
    _origin_point = Array
    _inside_vector = Array
    _axis_vector = Array
    _axis_pixel_vector = Array
    _end_axis_point = Array

    ticklabel_cache = List
    _cache_valid = Bool(False)

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, component=None, **kwargs):
        # TODO: change this back to a factory in the instance trait some day
        self.tick_generator = DefaultTickGenerator()
        # Override init so that our component gets set last.  We want the
        # _component_changed() event handler to get run last.
        super(PlotAxis, self).__init__(**kwargs)
        if component is not None:
            self.component = component

    def invalidate(self):
        """ Invalidates the pre-computed layout and scaling data.
        """
        self._reset_cache()
        self.invalidate_draw()
        return

    def traits_view(self):
        """ Returns a View instance for use with Traits UI.  This method is
        called automatically by the Traits framework when .edit_traits() is
        invoked.
        """
        from .axis_view import AxisView
        return AxisView

    #------------------------------------------------------------------------
    # PlotComponent and AbstractOverlay interface
    #------------------------------------------------------------------------

    def _do_layout(self, *args, **kw):
        """ Tells this component to do layout at a given size.

        Overrides Component.
        """
        if self.use_draw_order and self.component is not None:
            self._layout_as_overlay(*args, **kw)
        else:
            super(PlotAxis, self)._do_layout(*args, **kw)
        return

    def overlay(self, component, gc, view_bounds=None, mode='normal'):
        """ Draws this component overlaid on another component.

        Overrides AbstractOverlay.
        """
        if not self.visible:
            return
        self._draw_component(gc, view_bounds, mode, component)
        return

    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.
        """
        self._draw_component(gc, view_bounds, mode)
        return

    def _draw_component(self, gc, view_bounds=None, mode='normal',
                        component=None):
        """ Draws the component.

        This method is preserved for backwards compatibility.  Overrides
        PlotComponent.  Geometry, tick positions and labels are recomputed
        only when the cache is invalid.
        """
        if not self.visible:
            return

        if not self._cache_valid:
            if component is not None:
                self._calculate_geometry_overlay(component)
            else:
                self._calculate_geometry()
            self._compute_tick_positions(gc, component)
            self._compute_labels(gc)

        with gc:
            # slight optimization: if we set the font correctly on the
            # base gc before handing it in to our title and tick labels,
            # their set_font() won't have to do any work.
            gc.set_font(self.tick_label_font)

            if self.axis_line_visible:
                self._draw_axis_line(gc, self._origin_point,
                                     self._end_axis_point)
            if self.title:
                self._draw_title(gc)

            self._draw_ticks(gc)
            self._draw_labels(gc)

        self._cache_valid = True
        return

    #------------------------------------------------------------------------
    # Private draw routines
    #------------------------------------------------------------------------

    def _layout_as_overlay(self, size=None, force=False):
        """ Lays out the axis as an overlay on another component.

        The axis is sized and positioned to fill the component's padding on
        the side given by ``orientation``.
        """
        if self.component is not None:
            if self.orientation in ("left", "right"):
                self.y = self.component.y
                self.height = self.component.height
                if self.orientation == "left":
                    self.width = self.component.padding_left
                    self.x = self.component.outer_x
                elif self.orientation == "right":
                    self.width = self.component.padding_right
                    self.x = self.component.x2 + 1
            else:
                self.x = self.component.x
                self.width = self.component.width
                if self.orientation == "bottom":
                    self.height = self.component.padding_bottom
                    self.y = self.component.outer_y
                elif self.orientation == "top":
                    self.height = self.component.padding_top
                    self.y = self.component.y2 + 1
        return

    def _draw_axis_line(self, gc, startpoint, endpoint):
        """ Draws the line for the axis.
        """
        with gc:
            gc.set_antialias(0)
            gc.set_line_width(self.axis_line_weight)
            gc.set_stroke_color(self.axis_line_color_)
            gc.set_line_dash(self.axis_line_style_)
            gc.move_to(*around(startpoint))
            gc.line_to(*around(endpoint))
            gc.stroke_path()
        return

    def _draw_title(self, gc, label=None, axis_offset=None):
        """ Draws the title for the axis.
        """
        if label is None:
            title_label = Label(text=self.title,
                                font=self.title_font,
                                color=self.title_color,
                                rotate_angle=self.title_angle)
        else:
            title_label = label

        # get the _rotated_ bounding box of the label
        tl_bounds = array(title_label.get_bounding_box(gc), float64)
        text_center_to_corner = -tl_bounds/2.0
        # which axis are we moving away from the axis line along?
        axis_index = self._major_axis.argmin()

        if self.title_spacing != 'auto':
            axis_offset = self.title_spacing

        if (self.title_spacing) and (axis_offset is None):
            if not self.ticklabel_cache:
                axis_offset = 25
            else:
                # Clear the widest tick label by 30%.
                axis_offset = max(
                    [tick_label._bounding_box[axis_index]
                     for tick_label in self.ticklabel_cache]) * 1.3

        offset = (self._origin_point + self._end_axis_point)/2
        axis_dist = self.tick_out + tl_bounds[axis_index]/2.0 + axis_offset
        offset -= self._inside_vector * axis_dist
        offset += text_center_to_corner

        gc.translate_ctm(*offset)
        title_label.draw(gc)
        gc.translate_ctm(*(-offset))
        return

    def _draw_ticks(self, gc):
        """ Draws the tick marks for the axis.
        """
        if not self.tick_visible:
            return
        gc.set_stroke_color(self.tick_color_)
        gc.set_line_width(self.tick_weight)
        gc.set_antialias(False)
        gc.begin_path()
        tick_in_vector = self._inside_vector*self.tick_in
        tick_out_vector = self._inside_vector*self.tick_out
        for tick_pos in self._tick_positions:
            gc.move_to(*(tick_pos + tick_in_vector))
            gc.line_to(*(tick_pos - tick_out_vector))
        gc.stroke_path()
        return

    def _draw_labels(self, gc):
        """ Draws the tick labels for the axis.
        """
        # which axis are we moving away from the axis line along?
        axis_index = self._major_axis.argmin()

        inside_vector = self._inside_vector
        if self.tick_label_position == "inside":
            inside_vector = -inside_vector

        for i in range(len(self._tick_label_positions)):
            # We want a more sophisticated scheme than just 2 decimals all
            # the time
            ticklabel = self.ticklabel_cache[i]
            tl_bounds = self._tick_label_bounding_boxes[i]

            # base_position puts the tick label at a point where the vector
            # extending from the tick mark inside 8 units
            # just touches the rectangular bounding box of the tick label.
            # Note: This is not necessarily optimal for non
            # horizontal/vertical axes.  More work could be done on this.

            base_position = self._tick_label_positions[i].copy()
            axis_dist = self.tick_label_offset + tl_bounds[axis_index]/2.0
            base_position -= inside_vector * axis_dist
            base_position -= tl_bounds/2.0

            if self.tick_label_alignment == 'corner':
                if self.orientation in ("top", "bottom"):
                    base_position[0] += tl_bounds[0]/2.0
                elif self.orientation == "left":
                    base_position[1] -= tl_bounds[1]/2.0
                elif self.orientation == "right":
                    base_position[1] += tl_bounds[1]/2.0

            if self.ensure_labels_bounded:
                bound_idx = self._major_axis.argmax()
                if i == 0:
                    base_position[bound_idx] = max(
                        base_position[bound_idx],
                        self._origin_point[bound_idx])
                elif i == len(self._tick_label_positions)-1:
                    base_position[bound_idx] = min(
                        base_position[bound_idx],
                        self._end_axis_point[bound_idx] - tl_bounds[bound_idx])

            tlpos = around(base_position)
            gc.translate_ctm(*tlpos)
            ticklabel.draw(gc)
            gc.translate_ctm(*(-tlpos))
        return

    #------------------------------------------------------------------------
    # Private methods for computing positions and layout
    #------------------------------------------------------------------------

    def _reset_cache(self):
        """ Clears the cached tick positions, labels, and label positions.
        """
        self._tick_positions = []
        self._tick_label_list = []
        self._tick_label_positions = []
        return

    def _compute_tick_positions(self, gc, overlay_component=None):
        """ Calculates the positions for the tick marks.

        Raises
        ------
        RuntimeError
            If the data range's low value is greater than its high value.
        """
        if self.mapper is None:
            self._reset_cache()
            self._cache_valid = True
            return

        datalow = self.mapper.range.low
        datahigh = self.mapper.range.high
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos
        if overlay_component is not None:
            origin = getattr(overlay_component, 'origin', 'bottom left')
        else:
            origin = self.origin
        if self.orientation in ("top", "bottom"):
            if "right" in origin:
                flip_from_gc = True
            else:
                flip_from_gc = False
        elif self.orientation in ("left", "right"):
            if "top" in origin:
                flip_from_gc = True
            else:
                flip_from_gc = False
        if flip_from_gc:
            screenlow, screenhigh = screenhigh, screenlow

        if (datalow == datahigh) or (screenlow == screenhigh) or \
                (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
            self._reset_cache()
            self._cache_valid = True
            return

        if datalow > datahigh:
            raise RuntimeError("DataRange low is greater than high; unable to compute axis ticks.")

        if not self.tick_generator:
            return

        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            # generate ticks and labels simultaneously
            tmp = self.tick_generator.get_ticks_and_labels(datalow, datahigh,
                                                           screenlow, screenhigh)
            if len(tmp) == 0:
                tick_list = []
                labels = []
            else:
                tick_list, labels = tmp
            # compute the labels here
            self.ticklabel_cache = [Label(text=lab,
                                          font=self.tick_label_font,
                                          color=self.tick_label_color)
                                    for lab in labels]
            self._tick_label_bounding_boxes = [
                array(ticklabel.get_bounding_box(gc), float64)
                for ticklabel in self.ticklabel_cache]
        else:
            scale = 'log' if isinstance(self.mapper, LogMapper) else 'linear'
            if self.small_haxis_style:
                tick_list = array([datalow, datahigh])
            else:
                tick_list = array(self.tick_generator.get_ticks(
                    datalow, datahigh, datalow, datahigh, self.tick_interval,
                    use_endpoints=False, scale=scale), float64)

        mapped_tick_positions = (array(self.mapper.map_screen(tick_list))
                                 - screenlow) / (screenhigh - screenlow)
        self._tick_positions = around(array(
            [self._axis_vector*tickpos + self._origin_point
             for tickpos in mapped_tick_positions]))
        self._tick_label_list = tick_list
        self._tick_label_positions = self._tick_positions
        return

    def _compute_labels(self, gc):
        """Generates the labels for tick marks.

        Called from ``_draw_component`` while the cache is invalid.  Does
        nothing when the tick generator provides its own labels, because
        ``_compute_tick_positions`` has already built them.
        """
        # tick labels are already computed
        if hasattr(self.tick_generator, "get_ticks_and_labels"):
            return

        formatter = self.tick_label_formatter

        def build_label(val):
            tickstring = formatter(val) if formatter is not None else str(val)
            return Label(text=tickstring,
                         font=self.tick_label_font,
                         color=self.tick_label_color,
                         rotate_angle=self.tick_label_rotate_angle,
                         margin=self.tick_label_margin)

        self.ticklabel_cache = [build_label(val)
                                for val in self._tick_label_list]
        self._tick_label_bounding_boxes = [
            array(ticklabel.get_bounding_box(gc), float)
            for ticklabel in self.ticklabel_cache]
        return

    def _calculate_geometry(self):
        origin = self.origin
        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            self._major_axis_size = self.bounds[0]
            self._minor_axis_size = self.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            if self.orientation == 'top':
                self._origin_point = array(self.position)
                self._inside_vector = array([0., -1.])
            else:  # self.orientation == 'bottom'
                self._origin_point = array(self.position) + array([0., self.bounds[1]])
                self._inside_vector = array([0., 1.])
            if "right" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            self._major_axis_size = self.bounds[1]
            self._minor_axis_size = self.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array(self.position) + array([self.bounds[0], 0.])
                self._inside_vector = array([1., 0.])
            else:  # self.orientation == 'right'
                self._origin_point = array(self.position)
                self._inside_vector = array([-1., 0.])
            if "top" in origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            self._origin_point -= self._inside_vector*self.tick_in

        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # This is the vector that represents one unit of data space in terms of screen space.
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector, self._axis_vector))
        return

    def _calculate_geometry_overlay(self, overlay_component=None):
        if overlay_component is None:
            overlay_component = self
        component_origin = getattr(overlay_component, "origin", 'bottom left')

        screenhigh = self.mapper.high_pos
        screenlow = self.mapper.low_pos

        if self.orientation in ('top', 'bottom'):
            self._major_axis_size = overlay_component.bounds[0]
            self._minor_axis_size = overlay_component.bounds[1]
            self._major_axis = array([1., 0.])
            self._title_orientation = array([0., 1.])
            if self.orientation == 'top':
                self._origin_point = array([overlay_component.x, overlay_component.y2])
                self._inside_vector = array([0.0, -1.0])
            else:
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([0.0, 1.0])
            if "right" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        elif self.orientation in ('left', 'right'):
            self._major_axis_size = overlay_component.bounds[1]
            self._minor_axis_size = overlay_component.bounds[0]
            self._major_axis = array([0., 1.])
            self._title_orientation = array([-1., 0])
            if self.orientation == 'left':
                self._origin_point = array([overlay_component.x, overlay_component.y])
                self._inside_vector = array([1.0, 0.0])
            else:
                self._origin_point = array([overlay_component.x2, overlay_component.y])
                self._inside_vector = array([-1.0, 0.0])
            if "top" in component_origin:
                screenlow, screenhigh = screenhigh, screenlow

        if self.ensure_ticks_bounded:
            self._origin_point -= self._inside_vector*self.tick_in

        self._end_axis_point = abs(screenhigh-screenlow)*self._major_axis + self._origin_point
        self._axis_vector = self._end_axis_point - self._origin_point
        # This is the vector that represents one unit of data space in terms of screen space.
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector, self._axis_vector))
        return

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _bounds_changed(self, old, new):
        super(PlotAxis, self)._bounds_changed(old, new)
        self._layout_needed = True
        self._invalidate()

    def _bounds_items_changed(self, event):
        super(PlotAxis, self)._bounds_items_changed(event)
        self._layout_needed = True
        self._invalidate()

    def _mapper_changed(self, old, new):
        # Move the "updated" listener from the old mapper to the new one.
        if old is not None:
            old.on_trait_change(self.mapper_updated, "updated", remove=True)
        if new is not None:
            new.on_trait_change(self.mapper_updated, "updated")
        self._invalidate()

    def mapper_updated(self):
        """
        Event handler that is bound to this axis's mapper's **updated** event
        """
        self._invalidate()

    def _position_changed(self, old, new):
        super(PlotAxis, self)._position_changed(old, new)
        self._cache_valid = False

    def _position_items_changed(self, event):
        super(PlotAxis, self)._position_items_changed(event)
        self._cache_valid = False

    def _position_changed_for_component(self):
        self._cache_valid = False

    def _position_items_changed_for_component(self):
        self._cache_valid = False

    def _bounds_changed_for_component(self):
        self._cache_valid = False
        self._layout_needed = True

    def _bounds_items_changed_for_component(self):
        self._cache_valid = False
        self._layout_needed = True

    def _origin_changed_for_component(self):
        self._invalidate()

    def _updated_fired(self):
        """If the axis bounds changed, redraw."""
        self._cache_valid = False
        return

    def _invalidate(self):
        self._cache_valid = False
        self.invalidate_draw()
        if self.component:
            self.component.invalidate_draw()
        return

    def _component_changed(self):
        if self.mapper is not None:
            # If there is a mapper set, just leave it be.
            return

        # Try to pick the most appropriate mapper for our orientation
        # and what information we can glean from our component.
        attrmap = {
            "left": ("ymapper", "y_mapper", "value_mapper"),
            "bottom": ("xmapper", "x_mapper", "index_mapper"),
        }
        attrmap["right"] = attrmap["left"]
        attrmap["top"] = attrmap["bottom"]

        component = self.component
        # The first attribute the component has wins.
        for attr in attrmap[self.orientation]:
            if hasattr(component, attr):
                self.mapper = getattr(component, attr)
                break

        # Keep our origin in sync with the component
        self.origin = getattr(component, 'origin', 'bottom left')
        return

    #------------------------------------------------------------------------
    # The following event handlers just invalidate our previously computed
    # Label instances and backbuffer if any of our visual attributes change.
    # TODO: refactor this stuff and the caching of contained objects (e.g. Label)
    #------------------------------------------------------------------------

    def _title_changed(self):
        self.invalidate_draw()
        if self.component:
            self.component.invalidate_draw()
        return

    def _anytrait_changed(self, name, old, new):
        """ For every trait that defines a visual attribute
            we just call _invalidate() when a change is made.
        """
        invalidate_traits = [
            'title_font',
            'title_spacing',
            'title_color',
            'title_angle',
            'tick_weight',
            'tick_color',
            'tick_label_font',
            'tick_label_color',
            'tick_label_rotate_angle',
            'tick_label_alignment',
            'tick_label_margin',
            'tick_label_offset',
            'tick_label_position',
            'tick_label_formatter',
            'tick_in',
            'tick_out',
            'tick_visible',
            'tick_interval',
            'tick_generator',
            'orientation',
            'origin',
            'axis_line_visible',
            'axis_line_color',
            'axis_line_weight',
            'axis_line_style',
            'small_haxis_style',
            'ensure_labels_bounded',
            'ensure_ticks_bounded',
        ]
        if name in invalidate_traits:
            self._invalidate()

    # ------------------------------------------------------------------------
    # Initialization-related methods
    # ------------------------------------------------------------------------

    def _title_angle_default(self):
        if self.orientation == 'left':
            return 90.0
        if self.orientation == 'right':
            return 270.0
        # Then self.orientation in {'top', 'bottom'}
        return 0.0

    #------------------------------------------------------------------------
    # Persistence-related methods
    #------------------------------------------------------------------------

    def __getstate__(self):
        # Derived/cached state is not pickled; __setstate__ invalidates the
        # cache so it is recomputed on the next draw.  Note the label cache
        # trait is `ticklabel_cache` (no leading underscore); the old key
        # '_ticklabel_cache' matched nothing, so Label objects were pickled.
        dont_pickle = [
            '_tick_list',
            '_tick_positions',
            '_tick_label_list',
            '_tick_label_positions',
            '_tick_label_bounding_boxes',
            '_major_axis_size',
            '_minor_axis_size',
            '_major_axis',
            '_title_orientation',
            '_title_angle',
            '_origin_point',
            '_inside_vector',
            '_axis_vector',
            '_axis_pixel_vector',
            '_end_axis_point',
            'ticklabel_cache',
            '_cache_valid',
        ]

        state = super(PlotAxis, self).__getstate__()
        for key in dont_pickle:
            if key in state:
                del state[key]

        return state

    def __setstate__(self, state):
        super(PlotAxis, self).__setstate__(state)
        # Re-register the listener on the mapper and force a recompute.
        self._mapper_changed(None, self.mapper)
        self._reset_cache()
        self._cache_valid = False
        return
class RangeSelection(AbstractController):
    """ Selects a range along the index or value axis.

    The user right-click-drags to select a region, which stays selected until
    the user left-clicks to deselect.
    """

    #: The axis to which this tool is perpendicular.
    axis = Enum("index", "value")

    #: The selected region, expressed as a tuple in data space.  This updates
    #: and fires change-events as the user is dragging.
    selection = Property

    #: Whether a new selection replaces ("set") or is added to ("append") the
    #: existing selections.  Chosen when a selection starts, depending on
    #: whether **append_key** is held down.
    selection_mode = Enum("set", "append")

    #: This event is fired whenever the user completes the selection, or when a
    #: finalized selection gets modified.  The value of the event is the data
    #: space range.
    selection_completed = Event

    #: The name of the metadata on the datasource that we will write
    #: self.selection to.
    metadata_name = Str("selections")

    #: The name of the metadata on the datasource that records
    #: **selection_mode** ("set" or "append").
    selection_mode_metadata_name = Str("selection_mode")

    #: The name of the metadata on the datasource holding a list of numpy
    #: boolean arrays that mask the datasource's data.
    mask_metadata_name = Str("selection_masks")

    #: The possible event states of this selection tool (overrides
    #: enable.Interactor).
    #:
    #: normal:
    #:     Nothing has been selected, and the user is not dragging the mouse.
    #: selecting:
    #:     The user is dragging the mouse and actively changing the
    #:     selection region; resizing of an existing selection also
    #:     uses this mode.
    #: selected:
    #:     The user has released the mouse and a selection has been
    #:     finalized.  The selection remains until the user left-clicks
    #:     or self.deselect() is called.
    #: moving:
    #:   The user is moving (not resizing) the selection range.
    event_state = Enum("normal", "selecting", "selected", "moving")

    #------------------------------------------------------------------------
    # Traits for overriding default object relationships
    #
    # By default, the RangeSelection assumes that self.component is a plot
    # and looks for the mapper and the axis_index on it.  If this is not the
    # case, then any (or all) three of these can be overridden by directly
    # assigning values to them.  To unset them and have them revert to default
    # behavior, assign "None" to them.
    #------------------------------------------------------------------------

    #: The plot associated with this tool.  By default, this is just
    #: self.component.
    plot = Property

    #: The mapper associated with this tool.  By default, this is the mapper
    #: on **plot** that corresponds to **axis**.
    mapper = Property

    #: The index to use for **axis**.  By default, this is derived from
    #: self.plot.orientation, but it can be overridden and set to 0 or 1.
    axis_index = Property

    #: List of listeners that listen to selection events.  Any listener with a
    #: ``set_value_selection`` method is called with each new selection.
    listeners = List

    #------------------------------------------------------------------------
    # Configuring interaction control
    #------------------------------------------------------------------------

    #: Can the user resize the selection once it has been drawn?
    enable_resize = Bool(True)

    #: The pixel distance between the mouse event and a selection endpoint at
    #: which the user action will be construed as a resize operation.
    resize_margin = Int(7)

    #: Allow the left button begin a selection?
    left_button_selects = Bool(False)

    #: Disable all left-mouse button interactions?
    disable_left_mouse = Bool(False)

    #: Allow the tool to be put into the deselected state via mouse clicks
    allow_deselection = Bool(True)

    #: The minimum span, in pixels, of a selection region.  Any attempt to
    #: select a region smaller than this will be treated as a deselection.
    minimum_selection = Int(5)

    #: The key which, if held down while the mouse is being dragged, will
    #: indicate that the selection should be appended to an existing selection
    #: as opposed to overwriting it.
    append_key = Instance(KeySpec, args=(None, "control"))

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The value of the override plot to use, if any.  If None, then uses
    # self.component.
    _plot = Trait(None, Any)

    # The value of the override mapper to use, if any.  If None, then uses the
    # mapper on self.component.
    _mapper = Trait(None, Any)

    # Shadow trait for the **axis_index** property.
    _axis_index = Trait(None, None, Int)

    # The data space start and end coordinates of the selected region,
    # expressed as an array.
    _selection = ArrayOrNone()

    # The selection in mask form (the mask this tool last appended to the
    # datasource's mask list, so it can be removed again).
    _selection_mask = Array

    # The end of the selection that is being actively modified by the mouse.
    _drag_edge = Enum("high", "low")

    #------------------------------------------------------------------------
    # These record the mouse position when the user is moving (not resizing)
    # the selection
    #------------------------------------------------------------------------

    # The position of the initial user click for moving the selection.
    _down_point = Array  # (x,y)

    # The data space coordinates of **_down_point**.
    _down_data_coord = Float

    # The original selection when the mouse went down to move the selection.
    _original_selection = Any

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def deselect(self, event=None):
        """ Deselects the highlighted region.

        This method essentially resets the tool.  It takes the event causing
        the deselection as an optional argument.
        """
        self.selection = None
        self.selection_completed = None
        self.event_state = "normal"
        self.component.request_redraw()
        if event:
            event.window.set_pointer("arrow")
            event.handled = True
        return

    #------------------------------------------------------------------------
    # Event handlers for the "selected" event state
    #------------------------------------------------------------------------

    def selected_left_down(self, event):
        """ Handles the left mouse button being pressed when the tool is in
        the 'selected' state.

        If the user is allowed to resize the selection, and the event occurred
        within the resize margin of an endpoint, then the tool switches to the
        'selecting' state so that the user can resize the selection.

        If the event is within the bounds of the selection region, then the
        tool switches to the 'moving' state.

        Otherwise, the selection becomes deselected.
        """
        if self.disable_left_mouse:
            return

        screen_bounds = self._get_selection_screencoords()
        if screen_bounds is None:
            self.deselect(event)
            return
        low = min(screen_bounds)
        high = max(screen_bounds)
        ndx = self.axis_index
        mouse_coord = (event.x, event.y)[ndx]

        if self.enable_resize:
            if (abs(mouse_coord - high) <= self.resize_margin) or \
                    (abs(mouse_coord - low) <= self.resize_margin):
                return self.selected_right_down(event)

        if low <= mouse_coord <= high:
            self.event_state = "moving"
            self._down_point = array([event.x, event.y])
            self._down_data_coord = \
                self.mapper.map_data(self._down_point)[ndx]
            self._original_selection = array(self.selection)
        elif self.allow_deselection:
            self.deselect(event)
        else:
            # Treat this as a combination deselect + left down
            self.deselect(event)
            self.normal_left_down(event)
        event.handled = True
        return

    def selected_right_down(self, event):
        """ Handles the right mouse button being pressed when the tool is in
        the 'selected' state.

        If the user is allowed to resize the selection, and the event occurred
        within the resize margin of an endpoint, then the tool switches to the
        'selecting' state so that the user can resize the selection.

        Otherwise, the selection becomes deselected, and a new selection is
        started.
        """
        if self.enable_resize:
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                ndx = self.axis_index
                mouse_coord = (event.x, event.y)[ndx]
                # We have to do a little swapping; the "end" point
                # is always what gets updated, so if the user
                # clicked on the starting point, we have to reverse
                # the sense of the selection.
                if abs(mouse_coord - end) <= self.resize_margin:
                    self.event_state = "selecting"
                    self._drag_edge = "high"
                    self.selecting_mouse_move(event)
                elif abs(mouse_coord - start) <= self.resize_margin:
                    self.event_state = "selecting"
                    self._drag_edge = "low"
                    self.selecting_mouse_move(event)
                else:
                    # Clicking away from both endpoints always restarts the
                    # selection (allow_deselection is not consulted here).
                    # Treat this as a combination deselect + right down
                    self.deselect(event)
                    self.normal_right_down(event)
        else:
            # Treat this as a combination deselect + right down
            self.deselect(event)
            self.normal_right_down(event)
        event.handled = True
        return

    def selected_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'selected' state.

        If the user is allowed to resize the selection, and the event
        occurred within the resize margin of an endpoint, then the cursor
        changes to indicate that the selection could be resized.

        Otherwise, the cursor is set to an arrow.
        """
        if self.enable_resize:
            # Change the mouse cursor when the user moves within the
            # resize margin
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                ndx = self.axis_index
                mouse_coord = (event.x, event.y)[ndx]
                if abs(mouse_coord - end) <= self.resize_margin or \
                        abs(mouse_coord - start) <= self.resize_margin:
                    self._set_sizing_cursor(event)
                    return
        event.window.set_pointer("arrow")
        event.handled = True
        return

    def selected_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selected' state.

        Sets the cursor to an arrow.
        """
        event.window.set_pointer("arrow")
        return

    #------------------------------------------------------------------------
    # Event handlers for the "moving" event state
    #------------------------------------------------------------------------

    def moving_left_up(self, event):
        """ Handles the left mouse button coming up when the tool is in the
        'moving' state.

        Switches the tool to the 'selected' state.
        """
        if self.disable_left_mouse:
            return

        self.event_state = "selected"
        self.selection_completed = self.selection
        self._down_point = []
        event.handled = True
        return

    def moving_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'moving' state.

        Moves the selection range by an amount corresponding to the amount
        that the mouse has moved since its button was pressed.  If the new
        selection range overlaps the endpoints of the data, it is truncated to
        that endpoint.
        """
        cur_point = array([event.x, event.y])
        cur_data_point = self.mapper.map_data(cur_point)[self.axis_index]
        original_selection = self._original_selection
        new_selection = original_selection + \
            (cur_data_point - self._down_data_coord)
        selection_data_width = original_selection[1] - original_selection[0]

        # Keep the moved selection inside the mapper's data range while
        # preserving its width.
        data_range = self.mapper.range
        if min(new_selection) < data_range.low:
            new_selection = (data_range.low,
                             data_range.low + selection_data_width)
        elif max(new_selection) > data_range.high:
            new_selection = (data_range.high - selection_data_width,
                             data_range.high)

        self.selection = new_selection
        self.selection_completed = new_selection
        self.component.request_redraw()
        event.handled = True
        return

    def moving_mouse_leave(self, event):
        """ Handles the mouse leaving the plot while the tool is in the
        'moving' state.

        If the mouse was within the selection region when it left, the method
        does nothing.

        If the mouse was outside the selection region when it left, the event
        is treated as moving the selection to the minimum or maximum.
        """
        axis_index = self.axis_index
        low = self.plot.position[axis_index]
        high = low + self.plot.bounds[axis_index] - 1

        pos = self._get_axis_coord(event)
        if low <= pos <= high:
            # the mouse left but was within the mapping range, so don't do
            # anything
            return
        # the mouse left and exceeds the mapping range, so we need to slam
        # the selection all the way to the minimum or the maximum
        self.moving_mouse_move(event)
        return

    def moving_mouse_enter(self, event):
        """ Handles the mouse re-entering the plot while in the 'moving'
        state.

        If the left button is no longer down, this is treated as the left
        button having been released.
        """
        if not event.left_down:
            return self.moving_left_up(event)
        return

    #------------------------------------------------------------------------
    # Event handlers for the "normal" event state
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        """ Handles the left mouse button being pressed when the tool is in
        the 'normal' state.

        If the tool allows the left mouse button to start a selection, then
        it does so.
        """
        if self.left_button_selects:
            return self.normal_right_down(event)

    def normal_right_down(self, event):
        """ Handles the right mouse button being pressed when the tool is in
        the 'normal' state.

        Puts the tool into 'selecting' mode, changes the cursor to show that
        it is selecting, and starts defining the selection.
        """
        # Decide the selection mode first: the selection setter writes
        # selection_mode into the datasource metadata, so it must already be
        # correct for the very first (degenerate) selection written below.
        if self.append_key is not None and self.append_key.match(event):
            self.selection_mode = "append"
        else:
            self.selection_mode = "set"

        pos = self._get_axis_coord(event)
        mapped_pos = self.mapper.map_data(pos)
        self.selection = (mapped_pos, mapped_pos)
        self._set_sizing_cursor(event)
        self._down_point = array([event.x, event.y])
        self.event_state = "selecting"
        self.selecting_mouse_move(event)
        return

    #------------------------------------------------------------------------
    # Event handlers for the "selecting" event state
    #------------------------------------------------------------------------

    def selecting_mouse_move(self, event):
        """ Handles the mouse being moved when the tool is in the 'selecting'
        state.

        Expands the selection range at the appropriate end, based on the new
        mouse position.
        """
        if self.selection is not None:
            axis_index = self.axis_index
            low = self.plot.position[axis_index]
            high = low + self.plot.bounds[axis_index] - 1
            pos = self._get_axis_coord(event)
            if low <= pos <= high:
                new_edge = self.mapper.map_data(pos)
                if self._drag_edge == "high":
                    low_val = self.selection[0]
                    if new_edge >= low_val:
                        self.selection = (low_val, new_edge)
                    else:
                        # Dragged past the fixed end: swap which edge moves.
                        self.selection = (new_edge, low_val)
                        self._drag_edge = "low"
                else:
                    high_val = self.selection[1]
                    if new_edge <= high_val:
                        self.selection = (new_edge, high_val)
                    else:
                        # Dragged past the fixed end: swap which edge moves.
                        self.selection = (high_val, new_edge)
                        self._drag_edge = "high"

                self.component.request_redraw()
                event.handled = True
        return

    def selecting_button_up(self, event):
        """ Finishes a selection when the mouse button is released.

        If the dragged distance (in pixels) is smaller than
        **minimum_selection**, the selection is cancelled via deselect();
        otherwise the tool enters the 'selected' state and fires
        **selection_completed**.
        """
        event.window.set_pointer("arrow")

        end = self._get_axis_coord(event)

        if len(self._down_point) == 0:
            cancel_selection = False
        else:
            start = self._down_point[self.axis_index]
            self._down_point = []
            cancel_selection = self.minimum_selection > abs(start - end)

        if cancel_selection:
            self.deselect(event)
            event.handled = True
        else:
            self.event_state = "selected"

            # Fire the "completed" event
            self.selection_completed = self.selection
            event.handled = True
        return

    def selecting_right_up(self, event):
        """ Handles the right mouse button coming up when the tool is in the
        'selecting' state.

        Switches the tool to the 'selected' state and completes the selection.
        """
        self.selecting_button_up(event)

    def selecting_left_up(self, event):
        """ Handles the left mouse button coming up when the tool is in the
        'selecting' state.

        Switches the tool to the 'selected' state.
        """
        if self.disable_left_mouse:
            return
        self.selecting_button_up(event)

    def selecting_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selecting' state.

        Determines whether the event's position is outside the component's
        bounds, and if so, clips the selection.  Sets the cursor to an arrow.
        """
        axis_index = self.axis_index
        low = self.plot.position[axis_index]
        high = low + self.plot.bounds[axis_index] - 1

        old_selection = self.selection
        selection_low = old_selection[0]
        selection_high = old_selection[1]

        pos = self._get_axis_coord(event)
        if pos >= high:
            # clip to the boundary appropriate for the mapper's orientation.
            if self.mapper.sign == 1:
                selection_high = self.mapper.map_data(high)
            else:
                selection_high = self.mapper.map_data(low)
        elif pos <= low:
            if self.mapper.sign == 1:
                selection_low = self.mapper.map_data(low)
            else:
                selection_low = self.mapper.map_data(high)

        self.selection = (selection_low, selection_high)
        event.window.set_pointer("arrow")
        self.component.request_redraw()
        return

    def selecting_mouse_enter(self, event):
        """ Handles the mouse entering the plot when the tool is in the
        'selecting' state.

        If the mouse does not have the right mouse button down, this event
        is treated as if the right mouse button was released.  Otherwise,
        the method sets the cursor to show that it is selecting.
        """
        # If we were in the "selecting" state when the mouse left, and
        # the mouse has entered without a button being down,
        # then treat this like we got a button up event.
        if not (event.right_down or event.left_down):
            return self.selecting_button_up(event)
        else:
            self._set_sizing_cursor(event)
        return

    #------------------------------------------------------------------------
    # Property getter/setters
    #------------------------------------------------------------------------

    def _get_plot(self):
        if self._plot is not None:
            return self._plot
        else:
            return self.component

    def _set_plot(self, val):
        self._plot = val
        return

    def _get_mapper(self):
        if self._mapper is not None:
            return self._mapper
        else:
            return getattr(self.plot, self.axis + "_mapper")

    def _set_mapper(self, new_mapper):
        self._mapper = new_mapper
        return

    def _get_axis_index(self):
        if self._axis_index is None:
            return self._determine_axis()
        else:
            return self._axis_index

    def _set_axis_index(self, val):
        self._axis_index = val
        return

    def _get_selection(self):
        selection = getattr(self.plot, self.axis).metadata[self.metadata_name]
        return selection

    def _set_selection(self, val):
        """ Stores the selection and mirrors it into the datasource metadata
        (range, mask list and selection mode), then notifies listeners.
        """
        oldval = self._selection
        self._selection = val

        datasource = getattr(self.plot, self.axis, None)

        if datasource is not None:

            mdname = self.metadata_name

            # Set the selection range on the datasource
            datasource.metadata[mdname] = val
            datasource.metadata_changed = {mdname: val}

            # Remove the mask this tool added previously (matched by
            # identity) from the datasource's list of selection masks.
            selection_masks = \
                datasource.metadata.setdefault(self.mask_metadata_name, [])
            for index, mask in enumerate(selection_masks):
                if mask is self._selection_mask:
                    del selection_masks[index]
                    break

            # Set the selection mode on the datasource
            datasource.metadata[self.selection_mode_metadata_name] = \
                self.selection_mode

            if val is not None:
                low, high = val
                data_pts = datasource.get_data()
                new_mask = (data_pts >= low) & (data_pts <= high)
                selection_masks.append(new_mask)
                self._selection_mask = new_mask
            datasource.metadata_changed = {self.mask_metadata_name: val}

        self.trait_property_changed("selection", oldval, val)

        for listener in self.listeners:
            if hasattr(listener, "set_value_selection"):
                listener.set_value_selection(val)

        return

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _get_selection_screencoords(self):
        """ Returns a tuple of (x1, x2) screen space coordinates of the start
        and end selection points.

        If there is no current selection, then it returns None.
        """
        selection = self.selection
        if selection is not None and len(selection) == 2:
            return self.mapper.map_screen(array(selection))
        else:
            return None

    def _set_sizing_cursor(self, event):
        """ Sets the correct cursor shape on the window of the event, given
        the tool's orientation and axis.
        """
        if self.axis_index == 0:
            # horizontal range selection, so use left/right arrow
            event.window.set_pointer("size left")
        else:
            # vertical range selection, so use up/down arrow
            event.window.set_pointer("size top")
        return

    def _get_axis_coord(self, event, axis="index"):
        """ Returns the coordinate of the event along the axis of interest
        to this tool (or along the orthogonal axis, if axis="value").
        """
        event_pos = (event.x, event.y)
        if axis == "index":
            return event_pos[self.axis_index]
        else:
            return event_pos[1 - self.axis_index]

    def _determine_axis(self):
        """ Determines whether the index of the coordinate along this tool's
        axis of interest is the first or second element of an (x,y) coordinate
        tuple.

        This method is only called if self._axis_index hasn't been set (or is
        None).
        """
        if self.axis == "index":
            if self.plot.orientation == "h":
                return 0
            else:
                return 1
        else:   # self.axis == "value"
            if self.plot.orientation == "h":
                return 1
            else:
                return 0

    def __mapper_changed(self):
        # The mapper for our axis was replaced; the old selection no longer
        # corresponds to it.
        self.deselect()
        return

    def _axis_changed(self, old, new):
        """ Moves the mapper-change listener from the old axis' mapper trait
        on **plot** to the new axis' mapper trait.
        """
        plot = self.plot
        if plot is None:
            # Nothing to (un)hook yet; no plot/component has been assigned.
            return
        if old is not None:
            plot.on_trait_change(self.__mapper_changed, old + "_mapper",
                                 remove=True)
        if new is not None:
            # Previously this re-removed the *old* listener, so the new
            # axis' mapper was never watched.
            plot.on_trait_change(self.__mapper_changed, new + "_mapper")
        return
class DataLabel(ToolTip):
    """ A label on a point in data space.

    Optionally, an arrow is drawn to the point.
    """

    # The symbol to use if **marker** is set to "custom". This attribute must
    # be a compiled path for the given Kiva context.
    custom_symbol = Any

    # The point in data space where this label should anchor itself.
    data_point = ArrayOrNone()

    # The location of the data label relative to the data point.  Either a
    # named position string or an (x, y) screen-space offset sequence.
    label_position = LabelPositionTrait

    # The format string that determines the label's text.  This string is
    # formatted using a dict containing the keys 'x' and 'y', corresponding to
    # data space values.
    label_format = Str("(%(x)f, %(y)f)")

    # The text to show on the label, or above the coordinates for the label, if
    # show_label_coords is True.
    label_text = Str

    # Flag whether to show coordinates with the label or not.
    show_label_coords = Bool(True)

    # Does the label clip itself against the main plot area?  If not, then
    # the label draws into the padding area (where axes typically reside).
    clip_to_plot = Bool(True)

    # The center x position (average of x and x2)
    xmid = Property(Float, depends_on=['x', 'x2'])

    # The center y position (average of y and y2)
    ymid = Property(Float, depends_on=['y', 'y2'])

    # 'box' is a simple rectangular box, with an arrow that is a single line
    # with an arrowhead at the data point.
    # 'bubble' can be given rounded corners (by setting `corner_radius`), and
    # the 'arrow' is a thin triangular wedge with its point at the data point.
    # When label_style is 'bubble', the following traits are ignored:
    # arrow_size, arrow_color, arrow_root, and arrow_max_length.
    label_style = Enum('box', 'bubble')

    #----------------------------------------------------------------------
    # Marker traits
    #----------------------------------------------------------------------

    # Mark the point on the data that this label refers to?
    marker_visible = Bool(True)

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys.
    marker = MarkerTrait

    # The pixel size of the marker (doesn't include the thickness of the
    # outline).
    marker_size = Int(4)

    # The thickness, in pixels, of the outline to draw around the marker.
    # If this is 0, no outline will be drawn.
    marker_line_width = Float(1.0)

    # The color of the inside of the marker.
    marker_color = ColorTrait("red")

    # The color of the border drawn around the marker.
    marker_line_color = ColorTrait("black")

    #----------------------------------------------------------------------
    # Arrow traits
    #----------------------------------------------------------------------

    # Draw an arrow from the label to the data point?  Only
    # used if **data_point** is not None.
    arrow_visible = Bool(True)  # FIXME: replace with some sort of ArrowStyle

    # The length of the arrowhead, in screen points (e.g., pixels).
    arrow_size = Float(10)

    # The color of the arrow.
    arrow_color = ColorTrait("black")

    # The position of the base of the arrow on the label.  If this
    # is 'auto', then the label uses **label_position**.  Otherwise, it
    # treats the label as if it were at the label position indicated by
    # this attribute.
    arrow_root = Trait("auto", "auto", "top left", "top right", "bottom left",
                       "bottom right", "top center", "bottom center",
                       "left center", "right center")

    # The minimum length of the arrow before it will be drawn.  By default,
    # the arrow will be drawn regardless of how short it is.
    arrow_min_length = Float(0)

    # The maximum length of the arrow beyond which it will not be drawn.  By
    # default, the arrow will be drawn regardless of how long it is.
    arrow_max_length = Float(inf)

    #----------------------------------------------------------------------
    # Bubble traits
    #----------------------------------------------------------------------

    # The radius (in screen coordinates) of the curved corners of the "bubble".
    corner_radius = Float(10)

    #-------------------------------------------------------------------------
    # Private traits
    #-------------------------------------------------------------------------

    # Tuple (sx, sy) of the mapped screen coordinates of **data_point**.
    _screen_coords = Any

    # The arrow object returned by draw_arrow(); reused on later draws until
    # the layout or an arrow trait changes (which resets it to None).
    _cached_arrow = Any

    # When **arrow_root** is 'auto', this determines the location on the data
    # label from which the arrow is drawn, based on the position of the label
    # relative to its data point.
    _position_root_map = {
        "top left": "bottom right",
        "top right": "bottom left",
        "bottom left": "top right",
        "bottom right": "top left",
        "top center": "bottom center",
        "bottom center": "top center",
        "left center": "right center",
        "right center": "left center"
    }

    # Maps a root position name to the names of the label attributes holding
    # that point's screen (x, y) coordinates.
    _root_positions = {
        "bottom right": ("x2", "y"),
        "bottom left": ("x", "y"),
        "top right": ("x2", "y2"),
        "top left": ("x", "y2"),
        "top center": ("xmid", "y2"),
        "bottom center": ("xmid", "y"),
        "left center": ("x", "ymid"),
        "right center": ("x2", "ymid"),
    }

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws the tooltip overlaid on another component.

        Overrides and extends ToolTip.overlay()
        """
        # Read once so save_state/restore_state are always paired.
        clip = self.clip_to_plot
        if clip:
            gc.save_state()
            gc.clip_to_rect(component.x, component.y,
                            component.width, component.height)

        try:
            self.do_layout()

            if self.label_style == 'box':
                self._render_box(component, gc, view_bounds=view_bounds,
                                 mode=mode)
            else:
                self._render_bubble(component, gc, view_bounds=view_bounds,
                                    mode=mode)

            # draw the marker
            if self.marker_visible:
                render_markers(gc, [self._screen_coords],
                               self.marker, self.marker_size,
                               self.marker_color_, self.marker_line_width,
                               self.marker_line_color_, self.custom_symbol)
        finally:
            if clip:
                gc.restore_state()

    def _render_box(self, component, gc, view_bounds=None, mode='normal'):
        """ Render the 'box' style label, with a line arrow if enabled. """
        # draw the arrow if necessary
        if self.arrow_visible:
            if self._cached_arrow is None:
                if self.arrow_root in self._root_positions:
                    ox, oy = self._root_positions[self.arrow_root]
                else:
                    if self.arrow_root == "auto":
                        arrow_root = self.label_position
                    else:
                        arrow_root = self.arrow_root
                    # label_position may be an (x, y) offset rather than a
                    # named position; a list offset is unhashable, so only
                    # look up string positions in the map.
                    if isinstance(arrow_root, str):
                        pos = self._position_root_map.get(arrow_root)
                    else:
                        pos = None
                    # Fall back to the center of the label.
                    ox, oy = self._root_positions.get(
                        pos,
                        (self.x + self.width / 2, self.y + self.height / 2))

                # Named positions give attribute names; resolve them.
                if isinstance(ox, str):
                    ox = getattr(self, ox)
                    oy = getattr(self, oy)
                self._cached_arrow = draw_arrow(
                    gc, (ox, oy), self._screen_coords, self.arrow_color_,
                    arrowhead_size=self.arrow_size,
                    offset1=3,
                    offset2=self.marker_size + 3,
                    minlen=self.arrow_min_length,
                    maxlen=self.arrow_max_length)
            else:
                draw_arrow(gc, None, None, self.arrow_color_,
                           arrow=self._cached_arrow,
                           minlen=self.arrow_min_length,
                           maxlen=self.arrow_max_length)

        # layout and render the label itself
        ToolTip.overlay(self, component, gc, view_bounds, mode)

    def _render_bubble(self, component, gc, view_bounds=None, mode='normal'):
        """ Render the bubble label in the graphics context. """
        # (px, py) is the data point in screen space.
        px, py = self._screen_coords

        # (x, y) is the lower left corner of the label.
        x = self.x
        y = self.y
        # (x2, y2) is the upper right corner of the label.
        x2 = self.x2
        y2 = self.y2
        # r is the corner radius.
        r = self.corner_radius

        # Whether the wedge is actually drawn.  Defined up front so it also
        # exists when the arrow_visible trait is False (previously this raised
        # UnboundLocalError in that case).
        arrow_visible = False
        if self.arrow_visible:
            # FIXME: Make 'gap_width' a configurable trait (and give it a
            # better name).
            max_gap_width = 10
            gap_width = min(max_gap_width,
                            abs(x2 - x - 2 * r),
                            abs(y2 - y - 2 * r))
            region = find_region(px, py, x, y, x2, y2)

            # Figure out where the "arrow" connects to the "bubble".
            if region == 'left' or region == 'right':
                gap_start = py - gap_width / 2
                if gap_start < y + r:
                    gap_start = y + r
                elif gap_start > y2 - r - gap_width:
                    gap_start = y2 - r - gap_width
                by = gap_start + 0.5 * gap_width
                if region == 'left':
                    bx = x
                else:
                    bx = x2
            else:
                gap_start = px - gap_width / 2
                if gap_start < x + r:
                    gap_start = x + r
                elif gap_start > x2 - r - gap_width:
                    gap_start = x2 - r - gap_width
                bx = gap_start + 0.5 * gap_width
                if region == 'top':
                    by = y2
                else:
                    by = y

            arrow_len = sqrt((px - bx) ** 2 + (py - by) ** 2)
            arrow_visible = arrow_len >= self.arrow_min_length

        with gc:
            if self.border_visible:
                gc.set_line_width(self.border_width)
                gc.set_stroke_color(self.border_color_)
            else:
                gc.set_line_width(0)
                gc.set_stroke_color((0, 0, 0, 0))
            gc.set_fill_color(self.bgcolor_)

            # Start at the lower left, on the left edge where the curved
            # part of the box ends.
            gc.move_to(x, y + r)

            # Draw the left side and the upper left curved corner.
            if arrow_visible and region == 'left':
                gc.line_to(x, gap_start)
                gc.line_to(px, py)
                gc.line_to(x, gap_start + gap_width)
            gc.arc_to(x, y2, x + r, y2, r)

            # Draw the top and the upper right curved corner.
            if arrow_visible and region == 'top':
                gc.line_to(gap_start, y2)
                gc.line_to(px, py)
                gc.line_to(gap_start + gap_width, y2)
            gc.arc_to(x2, y2, x2, y2 - r, r)

            # Draw the right side and the lower right curved corner.
            if arrow_visible and region == 'right':
                gc.line_to(x2, gap_start + gap_width)
                gc.line_to(px, py)
                gc.line_to(x2, gap_start)
            gc.arc_to(x2, y, x2 - r, y, r)

            # Draw the bottom and the lower left curved corner.
            if arrow_visible and region == 'bottom':
                gc.line_to(gap_start + gap_width, y)
                gc.line_to(px, py)
                gc.line_to(gap_start, y)
            gc.arc_to(x, y, x, y + r, r)

            # Finish the "bubble".
            gc.draw_path()

            self._draw_overlay(gc)

    def _do_layout(self, size=None):
        """Computes the size and position of the label and arrow.

        Overrides and extends ToolTip._do_layout()
        """
        if not self.component or not hasattr(self.component, "map_screen"):
            return

        # Call the parent class layout first; the positioning below reads
        # width/height/outer_* (presumably sized by ToolTip's layout).
        ToolTip._do_layout(self)

        self._screen_coords = self.component.map_screen([self.data_point])[0]
        sx, sy = self._screen_coords

        if isinstance(self.label_position, str):
            orientation = self.label_position
            if ("left" in orientation) or ("right" in orientation):
                if " " not in orientation:
                    self.y = sy - self.height / 2
                if "left" in orientation:
                    self.outer_x = sx - self.outer_width - 1
                elif "right" in orientation:
                    self.outer_x = sx
            if ("top" in orientation) or ("bottom" in orientation):
                if " " not in orientation:
                    self.x = sx - self.width / 2
                if "bottom" in orientation:
                    self.outer_y = sy - self.outer_height - 1
                elif "top" in orientation:
                    self.outer_y = sy
            if "center" in orientation:
                if " " not in orientation:
                    self.x = sx - (self.width / 2)
                    self.y = sy - (self.height / 2)
                else:
                    self.x = sx - (self.outer_width / 2) - 1
                    self.y = sy - (self.outer_height / 2) - 1
        else:
            # label_position is an (x, y) screen-space offset.
            self.x = sx + self.label_position[0]
            self.y = sy + self.label_position[1]

        # The label moved, so any cached arrow is stale.
        self._cached_arrow = None
        return

    def _data_point_changed(self, old, new):
        if new is not None:
            self._create_new_labels()

    def _label_format_changed(self, old, new):
        self._create_new_labels()

    def _label_text_changed(self, old, new):
        self._create_new_labels()

    def _show_label_coords_changed(self, old, new):
        self._create_new_labels()

    def _create_new_labels(self):
        """ Rebuild the tooltip lines from label_text and, optionally, the
        formatted coordinates of data_point.
        """
        pt = self.data_point
        if pt is not None:
            if self.show_label_coords:
                self.lines = [
                    self.label_text,
                    self.label_format % {
                        "x": pt[0],
                        "y": pt[1]
                    }
                ]
            else:
                self.lines = [self.label_text]

    def _component_changed(self, old, new):
        # Stop listening to the old component's mappers and start listening
        # to the new component's mappers.
        for comp, attach in ((old, False), (new, True)):
            if comp is not None:
                if hasattr(comp, 'index_mapper'):
                    self._modify_mapper_listeners(comp.index_mapper,
                                                  attach=attach)
                if hasattr(comp, 'value_mapper'):
                    self._modify_mapper_listeners(comp.value_mapper,
                                                  attach=attach)
        return

    def _modify_mapper_listeners(self, mapper, attach=True):
        if mapper is not None:
            mapper.on_trait_change(self._handle_mapper, 'updated',
                                   remove=not attach)
        return

    def _handle_mapper(self):
        # This gets fired whenever a mapper on our plot fires its
        # 'updated' event.
        self._layout_needed = True

    @on_trait_change("arrow_size,arrow_root,arrow_min_length," +
                     "arrow_max_length")
    def _invalidate_arrow(self):
        self._cached_arrow = None
        self._layout_needed = True

    @on_trait_change("label_position,position,position_items,bounds," +
                     "bounds_items")
    def _invalidate_layout(self):
        self._layout_needed = True

    def _get_xmid(self):
        return 0.5 * (self.x + self.x2)

    def _get_ymid(self):
        return 0.5 * (self.y + self.y2)
array, argsort, concatenate, cos, diff, dot, dtype, empty, float32, isfinite, nonzero, pi, searchsorted, seterr, sin, int8 )

# Enthought library imports
from traits.api import Enum, ArrayOrNone

# Maps a sort-order name to an integer direction:
# +1 ascending, -1 descending, 0 flat.
delta = {'ascending': 1, 'descending': -1, 'flat': 0}

# Structured dtype for an RGBA color: float32 fields r, g, b, a.
rgba_dtype = dtype([('r', float32), ('g', float32), ('b', float32), ('a', float32)])

# Structured dtype for a 2D point: float fields x and y.
point_dtype = dtype([('x', float), ('y', float)])

# Dimensions

# A single array of numbers (or None).  No shape is enforced by this trait.
NumericalSequenceTrait = ArrayOrNone()

# A sequence of pairs of numbers, i.e., an Nx2 array (or None).
PointTrait = ArrayOrNone(shape=(None, 2))

# An NxM array of numbers or NxMxRGB(A) array of colors (or None).  The shape
# is not enforced by this trait.
ImageTrait = ArrayOrNone()

# A 3D array of numbers of shape (Nx, Ny, Nz) (or None).
CubeTrait = ArrayOrNone(shape=(None, None, None))

# This enumeration lists the fundamental mathematical coordinate types that
# Chaco supports.
DimensionTrait = Enum("scalar", "point", "image", "cube")
class PointObject(Object):
    """Represent a group of individual points in a mayavi scene."""

    #: Whether to draw text labels next to the points (see _show_labels).
    label = Bool(False)
    #: Scale passed to ``text3d`` for the labels.
    label_scale = Float(0.01)
    #: Enables the "Project" control in the view; set based on the type of
    #: points.
    projectable = Bool(False)
    #: True when a surface with more than one point is available
    #: (see _get_orientable).
    orientable = Property(depends_on=['nearest'])
    #: Text objects created by _show_labels.
    text3d = List
    #: Glyph size; kept in sync with the glyph's ``scale_factor``.
    point_scale = Float(10, label='Point Scale')

    # projection onto a surface
    #: Nearest-neighbor query over the surface points.
    nearest = Instance(_DistanceQuery)
    #: Inside/outside test for the surface; also provides ``surf['nn']``.
    check_inside = Instance(_CheckInside)
    #: 4x4 transform applied to surface points/normals; its inverse maps
    #: ``points`` into the surface's frame (see _update_projections).
    project_to_trans = ArrayOrNone(float, shape=(4, 4))
    project_to_surface = Bool(False, label='Project', desc='project points '
                              'onto the surface')
    orient_to_surface = Bool(False, label='Orient', desc='orient points '
                             'toward the surface')
    scale_by_distance = Bool(False, label='Dist.', desc='scale points by '
                             'distance from the surface')
    mark_inside = Bool(False, label='Mark', desc='mark points inside the '
                       'surface in a different color')
    #: The surface color; the marker color for inside points is derived
    #: from it in _update_colors.
    inside_color = RGBColor((0., 0., 0.))

    glyph = Instance(Glyph)
    resolution = Int(8)

    view = View(HGroup(Item('visible', show_label=False),
                       Item('color', show_label=False),
                       Item('opacity')))

    def __init__(self, view='points', has_norm=False, *args, **kwargs):
        """Init.

        Parameters
        ----------
        view : 'points' | 'cloud' | 'arrow'
            Whether the view options should be tailored to individual
            points, a point cloud, or an arrow (used for axes).
        has_norm : bool
            Whether a norm can be defined; adds view options based on point
            norms (default False).
        """
        assert view in ('points', 'cloud', 'arrow')
        self._view = view
        self._has_norm = bool(has_norm)
        super(PointObject, self).__init__(*args, **kwargs)

    def default_traits_view(self):
        """Build the TraitsUI view matching ``self._view``/``_has_norm``."""
        color = Item('color', show_label=False)
        scale = Item('point_scale', label='Size', width=_SCALE_WIDTH,
                     editor=laggy_float_editor_headscale)
        orient = Item('orient_to_surface',
                      enabled_when='orientable and not project_to_surface',
                      tooltip='Orient points toward the surface')
        dist = Item('scale_by_distance',
                    enabled_when='orientable and not project_to_surface',
                    tooltip='Scale points by distance from the surface')
        mark = Item('mark_inside',
                    enabled_when='orientable and not project_to_surface',
                    tooltip='Mark points inside the surface using a different '
                    'color')
        if self._view == 'arrow':
            visible = Item('visible', label='Show', show_label=False)
            return View(HGroup(visible, scale, 'opacity', 'label', Spring()))
        elif self._view == 'points':
            visible = Item('visible', label='Show', show_label=True)
            views = (visible, color, scale, 'label')
        else:
            assert self._view == 'cloud'
            visible = Item('visible', show_label=False)
            views = (visible, color, scale)

        if not self._has_norm:
            return View(HGroup(*views))

        group2 = HGroup(dist, Item('project_to_surface', show_label=True,
                                   enabled_when='projectable',
                                   tooltip='Project points onto the surface '
                                   '(for visualization, does not affect '
                                   'fitting)'),
                        orient, mark, Spring(), show_left=False)
        return View(HGroup(HGroup(*views), group2))

    @on_trait_change('label')
    def _show_labels(self, show):
        """Remove existing text labels and, if ``show``, draw new ones.

        In 'arrow' view a single label with ``self.name`` is placed at the
        first point; otherwise each point is labelled with its index.
        """
        _toggle_mlab_render(self, False)
        while self.text3d:
            text = self.text3d.pop()
            text.remove()
        if show and len(self.src.data.points) > 0:
            fig = self.scene.mayavi_scene
            if self._view == 'arrow':  # for axes
                x, y, z = self.src.data.points[0]
                self.text3d.append(text3d(
                    x, y, z, self.name, scale=self.label_scale,
                    color=self.color, figure=fig))
            else:
                for i, (x, y, z) in enumerate(np.array(self.src.data.points)):
                    self.text3d.append(text3d(
                        x, y, z, ' %i' % i, scale=self.label_scale,
                        color=self.color, figure=fig))
        _toggle_mlab_render(self, True)

    @on_trait_change('visible')
    def _on_hide(self):
        """Hide the labels whenever the points are hidden."""
        if not self.visible:
            self.label = False

    @on_trait_change('scene.activated')
    def _plot_points(self):
        """Add the points to the mayavi pipeline"""
        if self.scene is None:
            return
        if hasattr(self.glyph, 'remove'):
            self.glyph.remove()
        if hasattr(self.src, 'remove'):
            self.src.remove()

        _toggle_mlab_render(self, False)
        x, y, z = self.points.T
        fig = self.scene.mayavi_scene
        scatter = pipeline.scalar_scatter(x, y, z, fig=fig)
        if not scatter.running:
            # this can occur sometimes during testing w/ui.dispose()
            # NOTE(review): rendering stays disabled on this path.
            return
        # fig.scene.engine.current_object is scatter
        mode = 'arrow' if self._view == 'arrow' else 'sphere'
        glyph = pipeline.glyph(scatter, color=self.color,
                               figure=fig, scale_factor=self.point_scale,
                               opacity=1., resolution=self.resolution,
                               mode=mode)
        glyph.actor.property.backface_culling = True
        # Orientation/scale come from the point normals set in
        # _update_projections.
        glyph.glyph.glyph.vector_mode = 'use_normal'
        glyph.glyph.glyph.clamping = False
        if mode == 'arrow':
            glyph.glyph.glyph_source.glyph_position = 'tail'

        # Colors come from the per-point scalars through the LUT set in
        # _update_colors.
        glyph.actor.mapper.color_mode = 'map_scalars'
        glyph.actor.mapper.scalar_mode = 'use_point_data'
        glyph.actor.mapper.use_lookup_table_scalar_range = False

        self.src = scatter
        self.glyph = glyph

        self.sync_trait('point_scale', self.glyph.glyph.glyph, 'scale_factor')
        self.sync_trait('color', self.glyph.actor.property, mutual=False)
        self.sync_trait('visible', self.glyph)
        self.sync_trait('opacity', self.glyph.actor.property)
        self.sync_trait('mark_inside', self.glyph.actor.mapper,
                        'scalar_visibility')
        self.on_trait_change(self._update_points, 'points')
        self._update_marker_scaling()
        self._update_marker_type()
        self._update_colors()
        _toggle_mlab_render(self, True)

    def _nearest_default(self):
        return _DistanceQuery(np.zeros((1, 3)))

    def _get_nearest(self, proj_rr):
        """Return the transformed nearest surface points and normals.

        ``proj_rr`` are points in the surface frame; the matched surface
        points and normals are mapped through ``project_to_trans`` (normals
        without translation).
        """
        idx = self.nearest.query(proj_rr)[1]
        proj_pts = apply_trans(
            self.project_to_trans, self.nearest.data[idx])
        proj_nn = apply_trans(
            self.project_to_trans, self.check_inside.surf['nn'][idx],
            move=False)
        return proj_pts, proj_nn

    @on_trait_change('points,project_to_trans,project_to_surface,mark_inside,'
                     'nearest')
    def _update_projections(self):
        """Update the styles of the plotted points.

        Points are mapped into the surface frame with the inverse of
        ``project_to_trans`` and matched to their nearest surface vertex.
        Glyph normals point along that vertex's transformed normal and are
        scaled by ``250 * distance + 1``.  With ``project_to_surface`` the
        points are also moved onto the surface.  With ``mark_inside`` (and
        no projection) scalars are 0 for points inside the surface and 1
        otherwise (assuming ``check_inside`` returns True for inside).
        """
        if not hasattr(self.src, 'data'):
            return
        if self._view == 'arrow':
            self.src.data.point_data.normals = self.nn
            self.src.data.point_data.update()
            return
        # projections
        if len(self.nearest.data) <= 1 or len(self.points) == 0:
            return

        # Do the projections
        pts = self.points
        inv_trans = np.linalg.inv(self.project_to_trans)
        proj_rr = apply_trans(inv_trans, self.points)
        proj_pts, proj_nn = self._get_nearest(proj_rr)
        vec = pts - proj_pts  # point to the surface
        # Orient glyphs along the nearest surface normal in both modes;
        # ``nn`` must be bound whether or not the points are projected.
        nn = proj_nn
        if self.project_to_surface:
            pts = proj_pts
        if self.mark_inside and not self.project_to_surface:
            scalars = (~self.check_inside(proj_rr, verbose=False)).astype(int)
        else:
            scalars = np.ones(len(pts))
        # With this, a point exactly on the surface is of size point_scale
        dist = np.linalg.norm(vec, axis=-1, keepdims=True)
        self.src.data.point_data.normals = (250 * dist + 1) * nn
        self.src.data.point_data.scalars = scalars
        self.glyph.actor.mapper.scalar_range = [0., 1.]
        self.src.data.points = pts  # projection can change this
        self.src.data.point_data.update()

    @on_trait_change('color,inside_color')
    def _update_colors(self):
        """Set the glyph LUT: row 0 for inside points, row 1 for ``color``."""
        if self.glyph is None:
            return
        # inside_color is the surface color, let's try to get far
        # from that
        inside = np.array(self.inside_color)
        # if it's too close to gray, just use black:
        if np.mean(np.abs(inside - 0.5)) < 0.2:
            inside.fill(0.)
        else:
            inside = 1 - inside
        colors = np.array([tuple(inside) + (1,),
                           tuple(self.color) + (1,)]) * 255.
        self.glyph.module_manager.scalar_lut_manager.lut.table = colors

    @on_trait_change('project_to_surface,orient_to_surface')
    def _update_marker_type(self):
        """Use cylinders when projecting/orienting, spheres otherwise."""
        # not implemented for arrow
        if self.glyph is None or self._view == 'arrow':
            return
        defaults = DEFAULTS['coreg']
        gs = self.glyph.glyph.glyph_source
        res = getattr(gs.glyph_source, 'theta_resolution',
                      getattr(gs.glyph_source, 'resolution', None))
        if self.project_to_surface or self.orient_to_surface:
            gs.glyph_source = tvtk.CylinderSource()
            gs.glyph_source.height = defaults['eegp_height']
            gs.glyph_source.center = (0., -defaults['eegp_height'], 0)
            gs.glyph_source.resolution = res
        else:
            gs.glyph_source = tvtk.SphereSource()
            gs.glyph_source.phi_resolution = res
            gs.glyph_source.theta_resolution = res

    @on_trait_change('scale_by_distance,project_to_surface')
    def _update_marker_scaling(self):
        """Scale glyphs by the normal vectors only when scaling by distance
        of unprojected points."""
        if self.glyph is None:
            return
        if self.scale_by_distance and not self.project_to_surface:
            self.glyph.glyph.scale_mode = 'scale_by_vector'
        else:
            self.glyph.glyph.scale_mode = 'data_scaling_off'

    def _resolution_changed(self, new):
        """Push the new resolution to whichever glyph source is in use."""
        if not self.glyph:
            return
        gs = self.glyph.glyph.glyph_source.glyph_source
        if isinstance(gs, tvtk.SphereSource):
            gs.phi_resolution = new
            gs.theta_resolution = new
        elif isinstance(gs, tvtk.CylinderSource):
            gs.resolution = new
        else:  # ArrowSource
            gs.tip_resolution = new
            gs.shaft_resolution = new

    @cached_property
    def _get_orientable(self):
        return len(self.nearest.data) > 1
class Legend(AbstractOverlay):
    """ A legend for a plot.
    """
    # The font to use for the legend text.
    font = KivaFont("modern 12")

    # The amount of space between the content of the legend and the border.
    border_padding = Int(10)

    # The border is visible (overrides Enable Component).
    border_visible = True

    # The color of the text labels
    color = black_color_trait

    # The background color of the legend (overrides AbstractOverlay).
    bgcolor = white_color_trait

    # The position of the legend with respect to its overlaid component.  (This
    # attribute applies only if the legend is used as an overlay.)
    #
    # * ur = Upper Right
    # * ul = Upper Left
    # * ll = Lower Left
    # * lr = Lower Right
    align = Enum("ur", "ul", "ll", "lr")

    # The amount of space between legend items.
    line_spacing = Int(3)

    # The size of the icon or marker area drawn next to the label.
    icon_bounds = List([24, 24])

    # Amount of spacing between each label and its icon.
    icon_spacing = Int(5)

    # Map of labels (strings) to plot instances or lists of plot instances.  The
    # Legend determines the appropriate rendering of each plot's marker/line.
    plots = Dict

    # The list of labels to show and the order to show them in.  If this
    # list is blank, then the keys of self.plots is used and displayed in
    # alphabetical order.  Otherwise, only the items in the **labels**
    # list that are keys of self.plots are drawn in the legend.  Labels are
    # ordered from top to bottom.
    labels = List

    # Whether or not to hide plots that are not visible.  (This is checked during
    # layout.)  This option *will* filter out the items in **labels** above, so
    # if you absolutely, positively want to set the items that will always
    # display in the legend, regardless of anything else, then you should turn
    # this option off.  Otherwise, it usually makes sense that a plot renderer
    # that is not visible will also not be in the legend.
    hide_invisible_plots = Bool(True)

    # If hide_invisible_plots is False, we can still choose to render the names
    # of invisible plots with an alpha.
    invisible_plot_alpha = Float(0.33)

    # The renderer that draws the icons for the legend.
    composite_icon_renderer = Instance(AbstractCompositeIconRenderer)

    # Action that the legend takes when it encounters a plot whose icon it
    # cannot render:
    #
    # * 'skip': skip it altogether and don't render its name
    # * 'blank': render the name but leave the icon blank (color=self.bgcolor)
    # * 'questionmark': render a "question mark" icon
    error_icon = Enum("skip", "blank", "questionmark")

    # Should the legend clip to the bounds it needs, or to its parent?
    clip_to_component = Bool(False)

    # Resizable in both directions (overrides PlotComponent): with "h" and "v"
    # present, get_preferred_size() computes width and height from the label
    # content; a missing letter keeps the current outer size on that axis.
    resizable = "hv"

    # An optional title string to show on the legend.
    title = Str('')

    # If True, title is at top, if False then at bottom.
    title_at_top = Bool(True)

    # The legend draws itself in one pass when its parent is drawing
    # the **draw_layer** (overrides PlotComponent).
    unified_draw = True

    # The legend is drawn on the overlay layer of its parent (overrides
    # PlotComponent).
    draw_layer = "overlay"

    #------------------------------------------------------------------------
    # Private Traits
    #------------------------------------------------------------------------

    # A cached list of Label instances
    _cached_labels = List

    # A cached array of label sizes.
    _cached_label_sizes = ArrayOrNone()

    # A cached list of label names.
    _cached_label_names = CList

    # A list of the visible plots.  Each plot corresponds to the label at
    # the same index in _cached_label_names.  This list does not necessarily
    # correspond to self.plots.values() because it is sorted according to
    # the plot name and it potentially excludes invisible plots.
    _cached_visible_plots = CList

    # A cached array of label positions relative to the legend's origin
    _cached_label_positions = ArrayOrNone()

    def is_in(self, x, y):
        """ Overrides the parent class because the legend's alignment and
        padding do not cooperate with the basic implementation.

        This may just be caused by a questionable implementation of the
        legend tool, but it works by adjusting the padding.  The Component
        class implementation of is_in uses the outer positions, which
        include the padding.
        """
        in_x = (x >= self.x) and (x <= self.x + self.width)
        in_y = (y >= self.y) and (y <= self.y + self.height)

        return in_x and in_y

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        Implements AbstractOverlay.
        """
        self.do_layout()
        # align is a two-letter code: vertical ('u'/'l') then horizontal
        # ('r'/'l').
        valign, halign = self.align
        if valign == "u":
            y = component.y2 - self.outer_height
        else:
            y = component.y
        if halign == "r":
            x = component.x2 - self.outer_width
        else:
            x = component.x
        self.outer_position = [x, y]

        if self.clip_to_component:
            c = self.component
            with gc:
                gc.clip_to_rect(c.x, c.y, c.width, c.height)
                PlotComponent._draw(self, gc, view_bounds, mode)
        else:
            PlotComponent._draw(self, gc, view_bounds, mode)

    # The following two methods implement the functionality of the Legend
    # to act as a first-class component instead of merely as an overlay.
    # They make the Legend use the normal PlotComponent render methods when
    # it does not have a .component attribute, so that it can have its own
    # overlays (e.g. a PlotLabel).
    #
    # The core legend rendering method is named _draw_as_overlay() so that
    # it can be called from _draw_plot() when the Legend is not an overlay,
    # and from _draw_overlay() when the Legend is an overlay.

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        if self.component is None:
            self._draw_as_overlay(gc, view_bounds, mode)

    def _draw_overlay(self, gc, view_bounds=None, mode="normal"):
        if self.component is not None:
            self._draw_as_overlay(gc, view_bounds, mode)
        else:
            PlotComponent._draw_overlay(self, gc, view_bounds, mode)

    def _draw_as_overlay(self, gc, view_bounds=None, mode="normal"):
        """ Draws the overlay layer of a component.

        Overrides PlotComponent.  Each cached label is drawn from the top
        of the legend downward: its icon first (dimmed with
        ``invisible_plot_alpha`` when the plot is not visible), then its
        text.  The x and y of each label are recorded in
        ``_cached_label_positions`` for get_label_at().
        """
        # Determine the position we are going to draw at from our alignment
        # corner and the corresponding outer_padding parameters.  (Position
        # refers to the lower-left corner of our border.)

        # First draw the border, if necessary.  This sort of duplicates
        # the code in PlotComponent._draw_overlay, which is unfortunate;
        # on the other hand, overlays of overlays seem like a rather obscure
        # feature.

        with gc:
            gc.clip_to_rect(int(self.x), int(self.y),
                            int(self.width), int(self.height))
            edge_space = self.border_width + self.border_padding
            icon_width, icon_height = self.icon_bounds

            icon_x = self.x + edge_space
            text_x = icon_x + icon_width + self.icon_spacing
            y = self.y2 - edge_space

            positions = self._cached_label_positions
            if positions is None:
                # Caches were invalidated and no layout has run yet.
                return
            if len(positions) > 0:
                positions[:, 0] = icon_x

            for i, _label_name in enumerate(self._cached_label_names):
                # Compute the current label's position
                label_height = self._cached_label_sizes[i][1]
                y -= label_height
                positions[i][1] = y

                # Try to render the icon
                icon_y = y + (label_height - icon_height) / 2
                plots = self._cached_visible_plots[i]
                render_args = (gc, icon_x, icon_y, icon_width, icon_height)

                # Reset per label so a failure below never restores a
                # stale (or unbound) alpha from a previous iteration.
                old_alpha = None
                try:
                    if isinstance(plots, (list, tuple)):
                        # TODO: How do we determine if a *group* of plots is
                        # visible or not?  For now, just look at the first one
                        # and assume that applies to all of them
                        if not plots[0].visible:
                            # TODO: use gc.get_alpha() once the Mac kiva
                            # backend supports it; assume 1.0 until then.
                            old_alpha = 1.0
                            gc.set_alpha(self.invisible_plot_alpha)
                        if len(plots) == 1:
                            plots[0]._render_icon(*render_args)
                        else:
                            self.composite_icon_renderer.render_icon(
                                plots, *render_args)
                    elif plots is not None:
                        # Single plot
                        if not plots.visible:
                            old_alpha = 1.0
                            gc.set_alpha(self.invisible_plot_alpha)
                        plots._render_icon(*render_args)
                    icon_drawn = True
                except Exception:
                    icon_drawn = self._render_error(*render_args)

                if icon_drawn:
                    # Render the text
                    gc.translate_ctm(text_x, y)
                    gc.set_antialias(0)
                    self._cached_labels[i].draw(gc)
                    gc.set_antialias(1)
                    gc.translate_ctm(-text_x, -y)

                    # Advance y to the next label's baseline
                    y -= self.line_spacing
                if old_alpha is not None:
                    gc.set_alpha(old_alpha)

    def _render_error(self, gc, icon_x, icon_y, icon_width, icon_height):
        """ Renders an error icon or performs some other action when a
        plot is unable to render its icon.

        Returns True if something was actually drawn (and hence the legend
        needs to advance the line) or False if nothing was drawn.
        """
        if self.error_icon == "skip":
            return False
        elif self.error_icon == "blank" or self.error_icon == "questionmark":
            with gc:
                gc.set_fill_color(self.bgcolor_)
                gc.rect(icon_x, icon_y, icon_width, icon_height)
                gc.fill_path()
            return True
        else:
            return False

    def get_preferred_size(self):
        """ Computes the size and position of the legend based on the maximum
        size of the labels, the alignment, and position of the component to
        overlay.

        As a side effect, refreshes the label, size, position, name and plot
        caches used when drawing.  Returns ``[width, height]``; ``[0, 0]``
        when there are no plots.
        """
        # Gather the names of all the labels we will create
        if len(self.plots) == 0:
            return [0, 0]

        # Work on a copy: the title is inserted into this list below and
        # must never leak into the user's self.labels trait.
        label_names = list(self.labels) or sorted(self.plots)

        if self.hide_invisible_plots:
            visible_labels = []
            visible_plots = []
            for name in label_names:
                # If the user set self.labels, there might be a bad value,
                # so ensure that each name is actually in the plots dict.
                if name not in self.plots:
                    continue
                val = self.plots[name]
                # Rather than checking for a list/TraitListObject/etc., we
                # just check for the attribute first
                if hasattr(val, 'visible'):
                    if val.visible:
                        visible_labels.append(name)
                        visible_plots.append(val)
                else:
                    # If we have a list of renderers, add the name if any of
                    # them are visible
                    for renderer in val:
                        if renderer.visible:
                            visible_labels.append(name)
                            visible_plots.append(val)
                            break
            label_names = visible_labels
        else:
            # Keep names and plots index-aligned, skipping names that have
            # no plot just like the branch above.
            label_names = [name for name in label_names if name in self.plots]
            visible_plots = [self.plots[name] for name in label_names]

        # Create the labels
        labels = [self._create_label(text) for text in label_names]

        # For the legend title
        title_label = self._create_label(self.title)
        if self.title_at_top:
            labels.insert(0, title_label)
            label_names.insert(0, 'Legend Label')
            visible_plots.insert(0, None)
        else:
            labels.append(title_label)
            label_names.append(self.title)
            visible_plots.append(None)

        # We need a dummy GC in order to get font metrics
        dummy_gc = font_metrics_provider()
        label_sizes = array(
            [label.get_width_height(dummy_gc) for label in labels])

        if len(label_sizes) > 0:
            max_label_width = max(label_sizes[:, 0])
            total_label_height = (sum(label_sizes[:, 1])
                                  + (len(label_sizes) - 1) * self.line_spacing)
        else:
            max_label_width = 0
            total_label_height = 0

        legend_width = (max_label_width + self.icon_spacing
                        + self.icon_bounds[0] + self.hpadding
                        + 2 * self.border_padding)
        legend_height = (total_label_height + self.vpadding
                         + 2 * self.border_padding)

        self._cached_labels = labels
        self._cached_label_sizes = label_sizes
        self._cached_label_positions = zeros_like(label_sizes)
        self._cached_label_names = label_names
        self._cached_visible_plots = visible_plots

        if "h" not in self.resizable:
            legend_width = self.outer_width
        if "v" not in self.resizable:
            legend_height = self.outer_height
        return [legend_width, legend_height]

    def get_label_at(self, x, y):
        """ Returns the label object at (x,y), or None if there is none. """
        if self._cached_label_positions is None:
            return None
        for i, pos in enumerate(self._cached_label_positions):
            size = self._cached_label_sizes[i]
            corner = pos + size
            if (pos[0] <= x <= corner[0]) and (pos[1] <= y <= corner[1]):
                return self._cached_labels[i]
        return None

    def _do_layout(self):
        # Recompute when overlaying a component, or when any cache is empty.
        if self.component is not None or len(self._cached_labels) == 0 or \
                self._cached_label_sizes is None or \
                len(self._cached_label_names) == 0:
            width, height = self.get_preferred_size()
            self.outer_bounds = [width, height]

    def _create_label(self, text):
        """ Returns a new Label instance for the given text.  Subclasses can
        override this method to customize the creation of labels.
        """
        return Label(text=text, font=self.font, margin=0, color=self.color_,
                     bgcolor="transparent", border_width=0)

    def _composite_icon_renderer_default(self):
        return CompositeIconRenderer()

    #-- trait handlers --------------------------------------------------------
    def _anytrait_changed(self, name, old, new):
        if name in ("font", "border_padding", "padding", "line_spacing",
                    "icon_bounds", "icon_spacing", "labels", "plots",
                    "plots_items", "labels_items", "border_width", "align",
                    "position", "position_items", "bounds", "bounds_items",
                    "title", "title_at_top"):
            self._layout_needed = True
        if name == "color":
            self.get_preferred_size()

    def _plots_changed(self):
        """ Invalidate the caches.
        """
        self._cached_labels = []
        self._cached_label_sizes = None
        self._cached_label_names = []
        self._cached_visible_plots = []
        self._cached_label_positions = None

    def _title_at_top_changed(self, old, new):
        """ Trait handler for when self.title_at_top changes.

        Moving the title reorders every cached label, size and position, so
        all caches are dropped and a new layout is requested;
        get_preferred_size() then places the title consistently.  (This is
        also safe before the first layout, when the caches are empty.)
        """
        self._cached_labels = []
        self._cached_label_sizes = None
        self._cached_label_names = []
        self._cached_visible_plots = []
        self._cached_label_positions = None
        self._layout_needed = True