def __init__(self, ax, x, y, pickradius=5, which_button=1, **kwargs):
    """
    Build the scatter plot on *ax* and wire up pick handling.

    Parameters
    ----------
    ax : Axes
        The Axes on which to make the scatter plot
    x, y : float or array-like, shape (n, )
        The data positions.
    pickradius : float
        Pick radius, in points.
    which_button : int, default: 1
        Where 1=left, 2=middle, 3=right

    Other Parameters
    ----------------
    **kwargs : arguments to scatter
        Other keyword arguments are passed directly to the ``ax.scatter``
        command
    """
    super().__init__(ax)

    # The collection is always created pickable; the pick radius is applied
    # to the returned artist afterwards.
    self.scatter = collection = ax.scatter(x, y, picker=True, **kwargs)
    collection.set_pickradius(pickradius)

    # Registry for observers of this widget, plus the raw inputs.
    self._observers = CallbackRegistry()
    self._x, self._y = x, y
    self._button = which_button

    # Route pick events to the handler, then set up the initial value.
    self.connect_event("pick_event", self._on_pick)
    self._init_val()
def init_callbacks(self):
    '''Creates the object's callback registry and default callbacks.

    A fresh per-view registry is always created. The shared
    `callbacks_common` registry is only created when it has not already
    been set. Mouse and keyboard handlers are connected, a listener for
    the 'spy_classes_modified' event is registered on the shared registry
    and, unless disabled in the settings, a (initially inactive)
    RectangleSelector is created for interactive labeling.
    '''
    import warnings
    from spectral import settings
    from matplotlib.cbook import CallbackRegistry

    self.callbacks = CallbackRegistry()

    # callbacks_common may have been set to a shared external registry
    # (e.g., to the callbacks_common member of another ImageView object). So
    # don't create it if it has already been set.
    if self.callbacks_common is None:
        self.callbacks_common = CallbackRegistry()

    # Mouse callback
    self.cb_mouse = ImageViewMouseHandler(self)
    self.cb_mouse.connect()

    # Keyboard callback
    self.cb_keyboard = ImageViewKeyboardHandler(self)
    self.cb_keyboard.connect()

    # Class update event callback
    def updater(*args, **kwargs):
        # Adopt the class array carried by the event if this view does not
        # have one yet, then redraw.
        if self.classes is None:
            self.set_classes(args[0].classes)
        self.refresh()
    callback = MplCallback(registry=self.callbacks_common,
                           event='spy_classes_modified',
                           callback=updater)
    callback.connect()
    self.cb_classes_modified = callback

    if settings.imshow_enable_rectangle_selector is False:
        return
    try:
        from matplotlib.widgets import RectangleSelector
        self.selector = RectangleSelector(self.axes,
                                          self._select_rectangle,
                                          button=1,
                                          useblit=True,
                                          spancoords='data',
                                          drawtype='box',
                                          rectprops=self.selector_rectprops)
        self.selector.set_active(False)
    except Exception:
        # Best effort: the view remains usable without the selector. Only
        # ordinary errors are handled so that KeyboardInterrupt/SystemExit
        # still propagate.
        self.selector = None
        msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
        warnings.warn(msg)
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call
    :func:`~spectral.graphics.spypylab.imshow`, which creates, displays,
    and returns an :class:`ImageView` object. Creating an
    :class:`ImageView` object directly (or creating an instance of a
    subclass) enables additional customization of the image display (e.g.,
    overriding default event handlers). If the object is created directly,
    call the :meth:`show` method to display the image. The underlying
    image display functionality is implemented via
    :func:`matplotlib.pyplot.imshow`.
    '''

    selector_rectprops = dict(facecolor='red', edgecolor='black', alpha=0.5,
                              fill=True)
    selector_lineprops = dict(color='black', linestyle='-', linewidth=2,
                              alpha=0.5)

    def __init__(self, data=None, bands=None, classes=None, source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed. with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g.,
                to generate a spectrum plot when a user double-clicks on the
                image display).

        Keyword arguments:

            Any keyword that can be provided to
            :func:`~spectral.graphics.graphics.get_rgb` or
            :func:`matplotlib.imshow`.
        '''
        import spectral
        from spectral import settings

        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        # All defaults are established *before* the setters below run, so
        # that a value supplied through a setter (e.g., `colors` forwarded
        # to `set_classes`) is not overwritten afterwards.
        self.class_colors = spectral.spy_colors

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises ValueError if the first two dimensions of `data` do not match
        the image shape established by previously set data or classes.
        '''
        from .graphics import _get_rgb_kwargs

        # Validate before touching any state so that a rejected call leaves
        # the view unchanged.
        image_shape = data.shape[:2]
        if self._image_shape is not None and \
           image_shape != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')

        self.data = data
        self.bands = bands

        # Split off the keywords meant for get_rgb; whatever remains is
        # destined for matplotlib's imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = image_shape

        self.imshow_data_kwargs.update(kwargs)
        # Interpolation is tracked by the `interpolation` property rather
        # than stored with the imshow keywords.
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs.pop('interpolation')

        if kwargs and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
                  'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Sets parameters affecting RGB display of data.

        Accepts any keyword supported by
        :func:`~spectral.graphics.graphics.get_rgb`. Raises ValueError for
        any other keyword.
        '''
        from .graphics import _get_rgb_kwargs

        for k in kwargs:
            if k not in _get_rgb_kwargs:
                raise ValueError('Unexpected keyword: {0}'.format(k))
        self.rgb_kwargs = kwargs.copy()
        if self.is_shown:
            self._update_data_rgb()
            self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.'''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
            get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used. (Previously this
        # branch evaluated an expression and discarded the result, so the
        # cmap configured in `imshow_data_kwargs` was never applied.)
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.

        Raises ValueError if the shape of `classes` does not match the image
        shape established by previously set data or classes.
        '''
        from .graphics import _get_rgb_kwargs

        if classes is None:
            self.classes = None
            return

        # Validate before touching any state so that a rejected call leaves
        # the view unchanged.
        class_shape = classes.shape[:2]
        if self._image_shape is not None and \
           class_shape != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')

        self.classes = classes
        if self._image_shape is None:
            self._image_shape = class_shape
        if colors is not None:
            self.class_colors = colors

        # Keywords intended for get_rgb are not meaningful for the class
        # image, so drop them.
        kwargs = {k: v for (k, v) in kwargs.items()
                  if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        # Interpolation is tracked by the `interpolation` property rather
        # than stored with the imshow keywords.
        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = \
                self.imshow_class_kwargs.pop('interpolation')

        if kwargs and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
                  'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Renders the image data.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                If `mode` is not provided, a mode will be automatically
                selected, based on the data set in the ImageView.

            `fignum` (int):

                Figure number of the matplotlib figure in which to display
                the ImageView. If not provided, a new figure will be created.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            msg = 'ImageView.show should only be called once.'
            warnings.warn(UserWarning(msg))
            return

        set_mpl_interactive()

        kwargs = {}
        if fignum is not None:
            kwargs['num'] = fignum
        if settings.imshow_figure_size is not None:
            kwargs['figsize'] = settings.imshow_figure_size
        plt.figure(**kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is None:
            self._guess_mode()
        else:
            self.set_display_mode(mode)

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.'''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback
        def updater(*args, **kwargs):
            # Adopt the class array carried by the event if this view does
            # not have one yet, then redraw.
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()
        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            self.selector = RectangleSelector(
                self.axes,
                self._select_rectangle,
                button=1,
                useblit=True,
                spancoords='data',
                drawtype='box',
                rectprops=self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # Best effort: the view remains usable without the selector.
            # Only ordinary errors are handled so that KeyboardInterrupt and
            # SystemExit still propagate.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
                  'pixel class labeling will be unavailable.'
            warnings.warn(msg)

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`).

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`).
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id

            # The shared registry only exists once callbacks have been
            # initialized; without it there is nobody to notify.
            if self.callbacks_common is not None:
                event = SpyMplEvent('spy_classes_modified')
                event.classes = self.classes
                event.nchanged = n
                self.callbacks_common.process('spy_classes_modified', event)

            # Make selection rectangle go away. The selector is None when
            # it is disabled in the settings or could not be created.
            if self.selector is not None:
                self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        '''RectangleSelector callback: records the selected pixel region.

        `self.selection` is set to [row_start, row_stop, col_start,
        col_stop] (stop indices excluded), clipped to the image, or to None
        if either event is outside this view's axes or the rectangle does
        not intersect the image.
        '''
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        (nrows, ncols) = self._image_shape[:2]
        if (r2 < 0) or (r1 >= nrows) or (c2 < 0) or (c1 >= ncols):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, nrows - 1)
        c1 = max(c1, 0)
        c2 = min(c2, ncols - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data.'''
        import matplotlib.pyplot as plt

        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')

        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)

        # Work on a copy (as show_classes does) so the stored keywords stay
        # free of the interpolation entry, which is owned by the
        # `interpolation` property.
        kwargs = self.imshow_data_kwargs.copy()
        kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        # Class colors are given in the range [0, 255]; matplotlib expects
        # [0, 1].
        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()
        kwargs.update({'cmap': cm,
                       'vmin': 0,
                       'norm': NoNorm(),
                       'interpolation': self._interpolation})
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Regenerates the class array used for display.

        In "overlay" mode, pixels with class 0 are masked so the underlying
        data image shows through.
        '''
        if self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes,
                                         mask=(self.classes == 0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        if alpha < 0 or alpha > 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the axes title to `s` (no effect until the view is shown).'''
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt

        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        # A classes-only view has no data array to forward.
        if self.data is not None:
            view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.set_display_mode(self.display_mode)
        # Share the registry so class edits propagate between the views.
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out (`scale` > 1 zooms in).'''
        if self.axes is None:
            raise Exception('Cannot zoom image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.'''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: omit the class label if it cannot be read or
                # formatted.
                pass
        return s

    def __str__(self):
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += ' {0:<20}: {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation is None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += ' {0:<20}: {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += ' {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += ' {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
class Cursor:
    """A cursor for selecting Matplotlib artists.

    Attributes
    ----------
    bindings : dict
        See the *bindings* keyword argument to the constructor.
    annotation_kwargs : dict
        See the *annotation_kwargs* keyword argument to the constructor.
    annotation_positions : dict
        See the *annotation_positions* keyword argument to the constructor.
    highlight_kwargs : dict
        See the *highlight_kwargs* keyword argument to the constructor.
    """

    # Maps each artist (weakly) to the set of cursors attached to it, so
    # that a cursor stays alive as long as one of its artists does.
    _keep_alive = WeakKeyDictionary()

    def __init__(self,
                 artists,
                 *,
                 multiple=False,
                 highlight=False,
                 hover=False,
                 bindings=None,
                 annotation_kwargs=None,
                 annotation_positions=None,
                 highlight_kwargs=None):
        """Construct a cursor.

        Parameters
        ----------

        artists : List[Artist]
            A list of artists that can be selected by this cursor.

        multiple : bool, optional
            Whether multiple artists can be "on" at the same time (defaults to
            False).

        highlight : bool, optional
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.

        hover : bool, optional
            Whether to select artists upon hovering instead of by clicking.
            (Hovering over an artist while a button is pressed will not trigger
            a selection; right clicking on an annotation will still remove it.)

        bindings : dict, optional
            A mapping of button and keybindings to actions.  Valid entries are:

            ================ ==================================================
            'select'         mouse button to select an artist
                             (default: 1)
            'deselect'       mouse button to deselect an artist
                             (default: 3)
            'left'           move to the previous point in the selected path,
                             or to the left in the selected image
                             (default: shift+left)
            'right'          move to the next point in the selected path, or to
                             the right in the selected image
                             (default: shift+right)
            'up'             move up in the selected image
                             (default: shift+up)
            'down'           move down in the selected image
                             (default: shift+down)
            'toggle_enabled' toggle whether the cursor is active
                             (default: e)
            'toggle_visible' toggle default cursor visibility and apply it to
                             all cursors
                             (default: v)
            ================ ==================================================

            Missing entries will be set to the defaults.  In order to not
            assign any binding to an action, set it to ``None``.

        annotation_kwargs : dict, optional
            Keyword arguments passed to the `annotate
            <matplotlib.axes.Axes.annotate>` call.

        annotation_positions : List[dict], optional
            List of positions tried by the annotation positioning algorithm.

        highlight_kwargs : dict, optional
            Keyword arguments used to create a highlighted artist.

        Raises
        ------
        ValueError
            If both *hover* and *multiple* are set, if *bindings* contains an
            unknown action, or if two actions share the same (non-``None``)
            binding.
        """
        artists = list(artists)

        # Validate everything up front, before registering any keep-alive
        # reference or canvas callback, so that an invalid call does not
        # leave a half-constructed cursor attached to the figure.
        if hover and multiple:
            raise ValueError("'hover' and 'multiple' are incompatible")
        bindings = dict(
            ChainMap(bindings if bindings is not None else {},
                     _default_bindings))
        unknown_bindings = set(bindings) - set(_default_bindings)
        if unknown_bindings:
            raise ValueError("Unknown binding(s): {}".format(", ".join(
                sorted(unknown_bindings))))
        # ``None`` means "no binding" and may legitimately be used for
        # several actions, so it does not count as a duplicate.
        binding_counts = Counter(
            value for value in bindings.values() if value is not None)
        duplicate_bindings = [
            value for value, count in binding_counts.items() if count > 1
        ]
        if duplicate_bindings:
            raise ValueError("Duplicate binding(s): {}".format(", ".join(
                sorted(map(str, duplicate_bindings)))))

        # Be careful with GC.
        self._artists = [weakref.ref(artist) for artist in artists]

        for artist in artists:
            type(self)._keep_alive.setdefault(artist, set()).add(self)

        self._multiple = multiple
        self._highlight = highlight

        self._visible = True
        self._enabled = True
        self._selections = []
        self._last_auto_position = None
        self._callbacks = CallbackRegistry()

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            connect_pairs += [
                ("motion_notify_event", self._hover_handler),
                ("button_press_event", self._hover_handler),
            ]
        else:
            connect_pairs += [("button_press_event", self._nonhover_handler)]
        # One disconnector per (event, canvas) pair; see `remove`.
        self._disconnectors = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in {artist.figure.canvas
                           for artist in artists}
        ]

        self.bindings = bindings

        self.annotation_kwargs = (annotation_kwargs
                                  if annotation_kwargs is not None else
                                  copy.deepcopy(_default_annotation_kwargs))
        self.annotation_positions = (
            annotation_positions if annotation_positions is not None else
            copy.deepcopy(_default_annotation_positions))
        self.highlight_kwargs = (highlight_kwargs
                                 if highlight_kwargs is not None else
                                 copy.deepcopy(_default_highlight_kwargs))

    @property
    def artists(self):
        """The tuple of selectable artists.
        """
        # Work around matplotlib/matplotlib#6982: `cla()` does not clear
        # `.axes`.
        return tuple(filter(_is_alive, (ref() for ref in self._artists)))

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.
        """
        for sel in self._selections:
            if sel.annotation.axes is None:
                raise RuntimeError("Annotation unexpectedly removed; "
                                   "use 'cursor.remove_selection' instead")
        return tuple(self._selections)

    @property
    def visible(self):
        """Whether selections are visible by default.

        Setting this property also updates the visibility status of current
        selections.
        """
        return self._visible

    @visible.setter
    def visible(self, value):
        self._visible = value
        for sel in self.selections:
            sel.annotation.set_visible(value)
            sel.annotation.figure.canvas.draw_idle()

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.  When
        the event is emitted, the position of the annotation is temporarily
        set to ``(nan, nan)``; if this position is not explicitly set by a
        callback, then a suitable position will be automatically computed.

        Likewise, if the text alignment is not explicitly set but the position
        is, then a suitable alignment will be automatically computed.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        # Pre-fetch the figure and axes, as callbacks may actually unset them.
        figure = pi.artist.figure
        axes = pi.artist.axes
        if axes.get_renderer_cache() is None:
            figure.canvas.draw()  # Needed by draw_artist below anyways.
        renderer = pi.artist.axes.get_renderer_cache()
        # The `_MarkedStr` alignments act as "not set by the user" markers
        # that are detected further below.
        ann = pi.artist.axes.annotate(_pick_info.get_ann_text(*pi),
                                      xy=pi.target,
                                      xytext=(np.nan, np.nan),
                                      ha=_MarkedStr("center"),
                                      va=_MarkedStr("center"),
                                      visible=self.visible,
                                      **self.annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            hl = self.add_highlight(*pi)
            if hl:
                extras.append(hl)
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)

        # Check that `ann.axes` is still set, as callbacks may have removed the
        # annotation.
        if ann.axes and ann.xyann == (np.nan, np.nan):
            fig_bbox = figure.get_window_extent()
            ax_bbox = axes.get_window_extent()
            overlaps = []
            for idx, annotation_position in enumerate(
                    self.annotation_positions):
                ann.set(**annotation_position)
                # Work around matplotlib/matplotlib#7614: position update is
                # missing.
                ann.update_positions(renderer)
                bbox = ann.get_window_extent(renderer)
                overlaps.append((
                    _get_rounded_intersection_area(fig_bbox, bbox),
                    _get_rounded_intersection_area(ax_bbox, bbox),
                    # Avoid needlessly jumping around by breaking ties using
                    # the last used position as default.
                    idx == self._last_auto_position,
                ))
            auto_position = max(range(len(overlaps)), key=overlaps.__getitem__)
            ann.set(**self.annotation_positions[auto_position])
            self._last_auto_position = auto_position
        else:
            # The position was set explicitly; derive any alignment that was
            # left untouched from the sign of the text offset.
            if isinstance(ann.get_ha(), _MarkedStr):
                ann.set_ha({
                    -1: "right",
                    0: "center",
                    1: "left"
                }[np.sign(np.nan_to_num(ann.xyann[0]))])
            if isinstance(ann.get_va(), _MarkedStr):
                ann.set_va({
                    -1: "top",
                    0: "center",
                    1: "bottom"
                }[np.sign(np.nan_to_num(ann.xyann[1]))])

        if (extras
                or len(self.selections) > 1 and not self._multiple
                or not figure.canvas.supports_blit):
            # Either:
            #  - there may be more things to draw, or
            #  - annotation removal will make a full redraw necessary, or
            #  - blitting is not (yet) supported.
            figure.canvas.draw_idle()
        elif ann.axes:
            # Fast path, only needed if the annotation has not been immediately
            # removed.
            figure.draw_artist(ann)
            # Explicit argument needed on MacOSX backend.
            figure.canvas.blit(figure.bbox)
        # Removal comes after addition so that the fast blitting path works.
        if not self._multiple:
            for sel in self.selections[:-1]:
                self.remove_selection(sel)
        return sel

    def add_highlight(self, artist, *args, **kwargs):
        """Create, add and return a highlighting artist.

        This method should be called with an "unpacked" `Selection`,
        possibly with some fields set to None.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = _pick_info.make_highlight(
            artist, *args,
            **ChainMap({"highlight_kwargs": self.highlight_kwargs}, kwargs))
        if hl:
            artist.axes.add_artist(hl)
            return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as single
        argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on Matplotlib's implementation; in
        particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...

        Examples of callbacks::

            # Change the annotation text and alignment:
            lambda sel: sel.annotation.set(
                text=sel.artist.get_label(),  # or use e.g. sel.target.index
                ha="center", va="bottom")

            # Make label non-draggable:
            lambda sel: sel.annotation.draggable(False)

        Raises ValueError if *event* is neither ``"add"`` nor ``"remove"``.
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            # Decorator usage: return a function awaiting the callback.
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove a cursor.

        Remove all `Selection`\\s, disconnect all callbacks, and allow the
        cursor to be garbage collected.
        """
        for disconnector in self._disconnectors:
            disconnector()
        for sel in self.selections:
            self.remove_selection(sel)
        for cursors in type(self)._keep_alive.values():
            with suppress(KeyError):
                cursors.remove(self)

    def _nonhover_handler(self, event):
        """Dispatch a button press to selection or deselection."""
        if event.name == "button_press_event":
            if event.button == self.bindings["select"]:
                self._on_select_button_press(event)
            if event.button == self.bindings["deselect"]:
                self._on_deselect_button_press(event)

    def _hover_handler(self, event):
        """Select on plain mouse motion; deselect on the deselect button."""
        if event.name == "motion_notify_event" and event.button is None:
            # Filter away events where the mouse is pressed, in particular to
            # avoid conflicts between hover and draggable.
            self._on_select_button_press(event)
        elif (event.name == "button_press_event"
              and event.button == self.bindings["deselect"]):
            # Still allow removing the annotation by right clicking.
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of a
        #     double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return (self.enabled
                and event.canvas.widgetlock.locked() == event.dblclick)

    def _on_select_button_press(self, event):
        """Add a selection for the pickable artist closest to the event."""
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {
            ax: _reassigned_axes_event(event, ax)
            for ax in {artist.axes
                       for artist in self.artists}
        }
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        """Remove every selection whose annotation contains the event."""
        if not self._filter_mouse_event(event):
            return
        for sel in self.selections:
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self.remove_selection(sel)

    def _on_key_press(self, event):
        """Handle the toggle and navigation key bindings."""
        if event.key is None:
            # An unidentified key must not be mistaken for an action whose
            # binding has been disabled by setting it to None.
            return
        if event.key == self.bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self.bindings["toggle_visible"]:
            self.visible = not self.visible
        try:
            sel = self.selections[-1]
        except IndexError:
            return
        for key in ["left", "right", "up", "down"]:
            if event.key == self.bindings[key]:
                self.remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def remove_selection(self, sel):
        """Remove a `Selection`.
        """
        self._selections.remove(sel)
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation] + sel.extras}
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()
def __init__(self,
             artists,
             *,
             multiple=False,
             highlight=False,
             hover=False,
             bindings=None,
             annotation_kwargs=None,
             annotation_positions=None,
             highlight_kwargs=None):
    """Construct a cursor.

    Parameters
    ----------
    artists : List[Artist]
        A list of artists that can be selected by this cursor.

    multiple : bool, optional
        Whether multiple artists can be "on" at the same time (defaults to
        False).

    highlight : bool, optional
        Whether to also highlight the selected artist.  If so,
        "highlighter" artists will be placed as the first item in the
        :attr:`extras` attribute of the `Selection`.

    hover : bool, optional
        Whether to select artists upon hovering instead of by clicking.
        (Hovering over an artist while a button is pressed will not trigger
        a selection; right clicking on an annotation will still remove it.)

    bindings : dict, optional
        A mapping of button and keybindings to actions.  Valid entries are:

        ================ ==================================================
        'select'         mouse button to select an artist (default: 1)
        'deselect'       mouse button to deselect an artist (default: 3)
        'left'           move to the previous point in the selected path,
                         or to the left in the selected image
                         (default: shift+left)
        'right'          move to the next point in the selected path, or to
                         the right in the selected image
                         (default: shift+right)
        'up'             move up in the selected image (default: shift+up)
        'down'           move down in the selected image
                         (default: shift+down)
        'toggle_enabled' toggle whether the cursor is active (default: e)
        'toggle_visible' toggle default cursor visibility and apply it to
                         all cursors (default: v)
        ================ ==================================================

        Missing entries will be set to the defaults.  In order to not
        assign any binding to an action, set it to ``None``.

    annotation_kwargs : dict, optional
        Keyword arguments passed to the `annotate
        <matplotlib.axes.Axes.annotate>` call.

    annotation_positions : List[dict], optional
        List of positions tried by the annotation positioning algorithm.

    highlight_kwargs : dict, optional
        Keyword arguments used to create a highlighted artist.

    Raises
    ------
    ValueError
        If both ``hover`` and ``multiple`` are set, if ``bindings`` names an
        unknown action, or if two actions share the same (non-``None``)
        binding.
    """
    artists = list(artists)

    # Validate every argument *before* registering the cursor anywhere, so
    # that a rejected constructor call does not leave a half-initialized
    # cursor in ``_keep_alive`` or connected to the canvases.
    if hover and multiple:
        raise ValueError("'hover' and 'multiple' are incompatible")
    bindings = dict(
        ChainMap(bindings if bindings is not None else {},
                 _default_bindings))
    unknown_bindings = set(bindings) - set(_default_bindings)
    if unknown_bindings:
        raise ValueError("Unknown binding(s): {}".format(", ".join(
            sorted(unknown_bindings))))
    # ``None`` means "no binding"; any number of actions may be unbound, so
    # only real keys/buttons take part in the duplicate check.
    binding_counts = Counter(
        bound for bound in bindings.values() if bound is not None)
    duplicate_bindings = [
        bound for bound, count in binding_counts.items() if count > 1]
    if duplicate_bindings:
        raise ValueError("Duplicate binding(s): {}".format(", ".join(
            sorted(map(str, duplicate_bindings)))))

    # Be careful with GC: the cursor only holds weak references to the
    # artists, while the class-level registry makes each artist keep its
    # cursors alive.
    self._artists = [weakref.ref(artist) for artist in artists]
    for artist in artists:
        type(self)._keep_alive.setdefault(artist, set()).add(self)

    self._multiple = multiple
    self._highlight = highlight

    self._visible = True
    self._enabled = True
    self._selections = []
    self._last_auto_position = None
    self._callbacks = CallbackRegistry()

    connect_pairs = [("key_press_event", self._on_key_press)]
    if hover:
        connect_pairs += [
            ("motion_notify_event", self._hover_handler),
            ("button_press_event", self._hover_handler),
        ]
    else:
        connect_pairs += [("button_press_event", self._nonhover_handler)]
    # One zero-argument "disconnector" per (canvas, event) connection.
    canvases = {artist.figure.canvas for artist in artists}
    self._disconnectors = [
        partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
        for pair in connect_pairs
        for canvas in canvases
    ]

    self.bindings = bindings

    # Fall back to private copies of the module defaults so that per-cursor
    # tweaks do not leak into other cursors.
    self.annotation_kwargs = (annotation_kwargs
                              if annotation_kwargs is not None else
                              copy.deepcopy(_default_annotation_kwargs))
    self.annotation_positions = (
        annotation_positions if annotation_positions is not None else
        copy.deepcopy(_default_annotation_positions))
    self.highlight_kwargs = (highlight_kwargs
                             if highlight_kwargs is not None else
                             copy.deepcopy(_default_highlight_kwargs))
class scatter_selector(AxesWidget):
    """
    A widget for selecting a point in a scatter plot.

    Callbacks registered through `on_changed` receive ``(index, (x, y))``.
    """

    def __init__(self, ax, x, y, pickradius=5, which_button=1, **kwargs):
        """
        Create the scatter plot and selection machinery.

        Parameters
        ----------
        ax : Axes
            The Axes on which to make the scatter plot
        x, y : float or array-like, shape (n, )
            The data positions.
        pickradius : float
            Pick radius, in points.
        which_button : int, default: 1
            Where 1=left, 2=middle, 3=right

        Other Parameters
        ----------------
        **kwargs : arguments to scatter
            Other keyword arguments are passed directly to the ``ax.scatter``
            command
        """
        super().__init__(ax)
        self.scatter = ax.scatter(x, y, **kwargs, picker=True)
        self.scatter.set_pickradius(pickradius)
        self._observers = CallbackRegistry()
        self._x = x
        self._y = y
        self._button = which_button
        self.connect_event("pick_event", self._on_pick)
        self._init_val()

    def _init_val(self):
        """Initialize ``val`` to the first data point, as (index, (x, y))."""
        self.val = (0, (self._x[0], self._y[0]))

    def _on_pick(self, event):
        """Forward a pick on this widget's scatter to the observers."""
        # Pick events are delivered for every pickable artist on the canvas;
        # indices from another artist would not refer to our x/y data.
        if event.artist is not self.scatter:
            return
        if event.mouseevent.button == self._button:
            # Several points may lie within the pick radius; use the first.
            idx = event.ind[0]
            x = self._x[idx]
            y = self._y[idx]
            self._process(idx, (x, y))

    def _process(self, idx, val):
        """Notify every "picked" observer with the index and (x, y) value."""
        self._observers.process("picked", idx, val)

    def on_changed(self, func):
        """
        When a point is clicked call *func* with the newly selected point

        Parameters
        ----------
        func : callable
            Function to call when a point is picked.
            The function must accept a (int, tuple(float, float)) as its
            arguments.

        Returns
        -------
        int
            Connection id (which can be used to disconnect *func*)
        """
        return self._observers.connect("picked",
                                       lambda idx, val: func(idx, val))
class Cursor:
    """A cursor for selecting artists on a matplotlib figure.
    """

    # Maps each artist (weakly) to the list of cursors attached to it, so that
    # a cursor stays alive as long as one of its artists does.
    _keep_alive = WeakKeyDictionary()

    def __init__(self, artists, *, multiple=False, highlight=False,
                 hover=False, bindings=default_bindings):
        """Construct a cursor.

        Parameters
        ----------
        artists : List[Artist]
            A list of artists that can be selected by this cursor.
        multiple : bool
            Whether multiple artists can be "on" at the same time (defaults to
            False).
        highlight : bool
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.
        bindings : dict
            A mapping of button and keybindings to actions.  Valid entries
            are:

            =================== ===============================================
            'select'            mouse button to select an artist (default: 1)
            'deselect'          mouse button to deselect an artist (default: 3)
            'left'              move to the previous point in the selected
                                path, or to the left in the selected image
                                (default: shift+left)
            'right'             move to the next point in the selected path,
                                or to the right in the selected image
                                (default: shift+right)
            'up'                move up in the selected image
                                (default: shift+up)
            'down'              move down in the selected image
                                (default: shift+down)
            'toggle_visibility' toggle visibility of all cursors (default: d)
            'toggle_enabled'    toggle whether the cursor is active
                                (default: t)
            =================== ===============================================

        hover : bool
            Whether to select artists upon hovering instead of by clicking.

        Raises
        ------
        ValueError
            If both ``hover`` and ``multiple`` are set, if ``bindings`` has an
            unknown action, or if two actions share a non-``None`` binding.
        """
        artists = list(artists)

        # Validate before registering anything, so that a rejected call does
        # not leave a half-built cursor in ``_keep_alive``.
        if hover and multiple:
            raise ValueError("`hover` and `multiple` are incompatible")
        bindings = {**default_bindings, **bindings}
        if set(bindings) != set(default_bindings):
            raise ValueError("Unknown bindings")
        # Actions bound to None are unbound and may repeat freely.
        actually_bound = {k: v for k, v in bindings.items() if v is not None}
        if len(set(actually_bound.values())) != len(actually_bound):
            raise ValueError("Duplicate bindings")

        # Be careful with GC.
        self._artists = [weakref.ref(artist) for artist in artists]
        for artist in artists:
            type(self)._keep_alive.setdefault(artist, []).append(self)

        self._multiple = multiple
        self._highlight = highlight

        self._axes = {artist.axes for artist in artists}
        self._enabled = True
        self._selections = []
        self._callbacks = CallbackRegistry()

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            connect_pairs += [
                ("motion_notify_event", self._on_select_button_press)]
        else:
            connect_pairs += [
                ("button_press_event", self._on_button_press)]
        # One zero-argument disconnecting callable per (canvas, event) pair.
        self._disconnect_cids = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in {artist.figure.canvas for artist in artists}]

        self._bindings = bindings

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def artists(self):
        """The tuple of selectable artists.
        """
        # Drop artists whose weak reference has died.
        return tuple(filter(None, (ref() for ref in self._artists)))

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.
        """
        return tuple(self._selections)

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        ann = pi.artist.axes.annotate(
            _pick_info.get_ann_text(*pi),
            xy=pi.target,
            **default_annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            extras.append(self.add_highlight(pi.artist))
        if not self._multiple:
            # Single-selection mode: clear whatever was selected before.
            while self._selections:
                self._remove_selection(self._selections[-1])
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)
        sel.artist.figure.canvas.draw_idle()
        return sel

    def add_highlight(self, artist):
        """Create, add and return a highlighting artist.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = copy.copy(artist)
        hl.set(**default_highlight_kwargs)
        artist.axes.add_artist(hl)
        return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as single
        argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on :mod:`matplotlib`'s implementation; in
        particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            # Decorator form: return a callable still waiting for *func*.
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove all `Selection`\\s and disconnect all callbacks.
        """
        for disconnect_cid in self._disconnect_cids:
            disconnect_cid()
        while self._selections:
            self._remove_selection(self._selections[-1])
        # The cursor is now inert: stop the artists from keeping it alive.
        for cursors in type(self)._keep_alive.values():
            with suppress(ValueError):
                cursors.remove(self)

    def _on_button_press(self, event):
        # Dispatch a mouse press to the select and/or deselect handlers.
        if event.button == self._bindings["select"]:
            self._on_select_button_press(event)
        if event.button == self._bindings["deselect"]:
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of a
        #     double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return (self.enabled
                and event.canvas.widgetlock.locked() == event.dblclick)

    def _on_select_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {ax: _reassigned_axes_event(event, ax)
                          for ax in self._axes}
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        # Keep only the closest candidate.
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Iterate over a snapshot: removing from the live list while looping
        # over it would skip the selection following each removed one.
        for sel in self.selections:
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self._remove_selection(sel)

    def _on_key_press(self, event):
        if event.key == self._bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self._bindings["toggle_visibility"]:
            for sel in self._selections:
                sel.annotation.set_visible(not sel.annotation.get_visible())
                sel.annotation.figure.canvas.draw_idle()
        # Arrow-style bindings move the most recent selection, if any.
        if self._selections:
            sel = self._selections[-1]
        else:
            return
        for key in ["left", "right", "up", "down"]:
            if event.key == self._bindings[key]:
                self._remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def _remove_selection(self, sel):
        # Unregister *sel*, remove its artists, and emit the "remove" event.
        self._selections.remove(sel)
        # Work around matplotlib/matplotlib#6785.
        draggable = sel.annotation._draggable
        try:
            draggable.disconnect()
            sel.annotation.figure.canvas.mpl_disconnect(
                sel.annotation._draggable._c1)
        except AttributeError:
            pass
        # (end of workaround).
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation, *sel.extras]}
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call
    :func:`~spectral.graphics.spypylab.imshow`, which creates, displays, and
    returns an :class:`ImageView` object. Creating an :class:`ImageView`
    object directly (or creating an instance of a subclass) enables additional
    customization of the image display (e.g., overriding default event
    handlers). If the object is created directly, call the :meth:`show` method
    to display the image. The underlying image display functionality is
    implemented via :func:`matplotlib.pyplot.imshow`.
    '''

    # Appearance of the interactive rectangle selector.
    selector_rectprops = dict(facecolor='red', edgecolor = 'black',
                              alpha=0.5, fill=True)
    selector_lineprops = dict(color='black', linestyle='-',
                              linewidth = 2, alpha=0.5)

    def __init__(self, data=None, bands=None, classes=None, source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed. with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g.,
                to generate a spectrum plot when a user double-clicks on the
                image display).

        Keyword arguments:

            Any keyword that can be provided to
            :func:`~spectral.graphics.graphics.get_rgb` or
            :func:`matplotlib.imshow`.
        '''
        import spectral
        from spectral import settings

        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        self.class_colors = spectral.spy_colors

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises `ValueError` if the first two dimensions of `data` differ from
        the previously established image shape.
        '''
        from .graphics import _get_rgb_kwargs

        self.data = data
        self.bands = bands

        # Split off the keywords meant for get_rgb; what remains is forwarded
        # to imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        elif data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')
        self.imshow_data_kwargs.update(kwargs)
        if 'interpolation' in self.imshow_data_kwargs:
            # Interpolation is tracked via the property, not the kwargs dict.
            self.interpolation = self.imshow_data_kwargs['interpolation']
            self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Sets parameters affecting RGB display of data.

        Accepts any keyword supported by
        :func:`~spectral.graphics.graphics.get_rgb`. Raises `ValueError` for
        any other keyword.
        '''
        from .graphics import _get_rgb_kwargs

        for k in kwargs:
            if k not in _get_rgb_kwargs:
                raise ValueError('Unexpected keyword: {0}'.format(k))
        self.rgb_kwargs = kwargs.copy()
        if self.is_shown:
            self._update_data_rgb()
            self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.'''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used.
        # NOTE(review): the statement below only evaluates an expression and
        # discards the result, so `data_rgb` is left unchanged. It does not
        # do what the comment above describes - confirm the intended
        # behavior before changing it.
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim ==3:
            (self.bands is not None and len(self.bands) == 1)

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.

        Raises `ValueError` if the shape of `classes` differs from the
        previously established image shape.
        '''
        from .graphics import _get_rgb_kwargs

        self.classes = classes
        if classes is None:
            return
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        elif classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')
        if colors is not None:
            self.class_colors = colors

        # Keywords intended for get_rgb do not apply to the class image.
        kwargs = dict([item for item in list(kwargs.items()) if item[0] not in \
                       _get_rgb_kwargs])
        self.imshow_class_kwargs.update(kwargs)

        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs['interpolation']
            self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Renders the image data.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                If `mode` is not provided, a mode will be automatically
                selected, based on the data set in the ImageView.

            `fignum` (int):

                Figure number of the matplotlib figure in which to display
                the ImageView. If not provided, a new figure will be created.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            msg = 'ImageView.show should only be called once.'
            warnings.warn(UserWarning(msg))
            return

        set_mpl_interactive()

        kwargs = {}
        if fignum is not None:
            kwargs['num'] = fignum
        if settings.imshow_figure_size is not None:
            kwargs['figsize'] = settings.imshow_figure_size
        plt.figure(**kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is None:
            self._guess_mode()
        else:
            self.set_display_mode(mode)

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.'''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()
        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops = \
                                                  self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # Best effort: the view remains usable without the selector, so
            # degrade to a warning instead of failing `show`.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`.

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`.
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            # Notify every view sharing the common registry.
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away. The selector is None when it
            # is disabled in settings or could not be created.
            if self.selector is not None:
                self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        '''Rectangle selector callback: stores the selected pixel region.'''
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        # Reject rectangles lying entirely outside the image.
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
          (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        # Clip the rectangle to the image bounds.
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        # Stored with exclusive stop indices, as `label_region` expects.
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data.'''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm
        from spectral import get_rgb

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        # Class colors are given in [0, 255]; the colormap gets them / 255.
        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        kwargs.update({'cmap': cm, 'vmin': 0,
                       'norm': NoNorm(),
                       'interpolation': self._interpolation})
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)
        #self.class_axes.axes.set_axis_bgcolor('black')

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Rebuilds the array used to draw the class image.'''
        if self.display_mode == 'overlay':
            # Mask class 0 so it is left out of the overlay.
            self.class_rgb = np.ma.array(self.classes, mask=(self.classes==0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        if alpha < 0 or alpha > 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the title of the axes (only has an effect once shown).'''
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.set_display_mode(self.display_mode)
        # Share the common registry with the new view.
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out (`scale` > 1 zooms in).'''
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.'''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: leave the class out if it cannot be looked up.
                pass
        return s

    def __str__(self):
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += ' {0:<20}: {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation is None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += ' {0:<20}: {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += ' {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += ' {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
class NDWindow(wx.Frame):
    '''A window class for displaying N-dimensional data points.

    The frame owns a wx OpenGL canvas. Every pixel of an RxCxB data array is
    drawn as a single point whose color is taken from its class label, and
    points can be reassigned to another class by selecting them with a box.
    '''

    def __init__(self, data, parent, id, *args, **kwargs):
        '''Creates the frame, its OpenGL canvas, and the default view state.

        ARGUMENTS:

            `data` (numpy.ndarray):

                Array of data points; the last axis indexes the features.

            `parent`, `id`:

                Passed through to the `wx.Frame` constructor.

        KEYWORD ARGUMENTS:

            `size`:

                Window size, unpacked into `wx.Size`. Defaults to
                `DEFAULT_WIN_SIZE`.

            `title` (str):

                Window title. Defaults to 'ND Window'.

            `classes` (numpy.ndarray):

                Integer class labels, one per data point. Defaults to an
                array of zeros with shape `data.shape[:-1]`.

            `features` (list):

                Feature indices to display. Defaults to `list(range(6))`.

            `labels` (list):

                Axis labels, indexed by feature. Defaults to the feature
                indices themselves.
        '''
        # Imported locally so the constructor does not rely on a module-level
        # `settings` name being present.
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.kwargs = kwargs
        self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
        self.title = kwargs.get('title', 'ND Window')

        # Forcing a specific style on the window.
        # Should this include styles passed?
        style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE

        super(NDWindow, self).__init__(parent, id, self.title,
                                       wx.DefaultPosition,
                                       wx.Size(*self.size),
                                       style,
                                       self.title)

        self.gl_initialized = False
        attribs = (glcanvas.WX_GL_RGBA,
                   glcanvas.WX_GL_DOUBLEBUFFER,
                   glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
        self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
        self.canvas.context = wx.glcanvas.GLContext(self.canvas)

        self._have_glut = False
        self.clear_color = (0, 0, 0, 0)
        self.show_axes_tf = True
        self.point_size = 1.0
        self._show_unassigned = True
        self._refresh_display_lists = False
        self._click_tolerance = 1
        self._display_commands = []
        self._selection_box = None
        self._rgba_indices = None
        self.mouse_panning = False
        self.win_pos = (100, 100)
        self.fovy = 60.
        self.znear = 0.1
        self.zfar = 10.0
        self.target_pos = [0.0, 0.0, 0.0]
        self.camera_pos_rtp = [7.0, 45.0, 30.0]
        self.up = [0.0, 0.0, 1.0]

        self.quadrant_mode = None
        self.mouse_handler = MouseHandler(self)

        # Set the event handlers.
        self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        self.Bind(wx.EVT_SIZE, self.on_resize)
        self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
        self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
        self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
        self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
        self.canvas.Bind(wx.EVT_CHAR, self.on_char)
        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
        self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

        self.data = data
        # `np.int` was removed from NumPy (1.24); the builtin `int` is the
        # documented replacement. The default is only built when needed.
        classes = kwargs.get('classes')
        if classes is None:
            classes = np.zeros(data.shape[:-1], int)
        self.classes = classes
        self.features = kwargs.get('features', list(range(6)))
        self.labels = kwargs.get('labels', list(range(data.shape[-1])))
        self.max_menu_class = int(np.max(self.classes.ravel() + 1))

        self.callbacks = CallbackRegistry()

    def on_event_close(self, event=None):
        '''Close-event hook; does nothing by default.'''
        pass

    def right_click(self, event):
        '''Opens the `MouseMenu` popup at the position of the click.'''
        self.canvas.SetCurrent(self.canvas.context)
        self.canvas.PopupMenu(MouseMenu(self), event.GetPosition())

    def add_display_command(self, cmd):
        '''Adds a command to be called next time `display` is run.'''
        self._display_commands.append(cmd)

    def reset_view_geometry(self):
        '''Sets viewing geometry to the default view.'''
        # All grid points will be adjusted to the range [0,1] so this
        # is a reasonable center coordinate for the scene
        self.target_pos = np.array([0.0, 0.0, 0.0])

        # Specify the camera location in spherical polar coordinates relative
        # to target_pos.
        self.camera_pos_rtp = [2.5, 45.0, 30.0]

    def set_data(self, data, **kwargs):
        '''Associates N-D point data with the window.

        ARGUMENTS:

            data (numpy.ndarray):

                An RxCxB array of data points to display.

        KEYWORD ARGUMENTS:

            classes (numpy.ndarray):

                An RxC array of integer class labels (zeros means unassigned).
                If omitted, the labels already held by the window are kept.

            features (list):

                Indices of features to display in the octant (see
                NDWindow.set_octant_display_features for description).

        Raises an Exception if the frame buffer does not have enough color
        bits to give every pixel a unique color.
        '''
        import OpenGL.GL as gl
        try:
            from OpenGL.GL import glGetIntegerv
        except ImportError:
            from OpenGL.GL.glget import glGetIntegerv

        # Honor the documented `classes` keyword (it was previously read and
        # then discarded).
        classes = kwargs.get('classes', None)
        if classes is not None:
            self.classes = classes

        features = kwargs.get('features', None)
        if features is None:
            features = list(range(6))
        # With fewer than 6 features only a single octant can be populated.
        # The display mode itself is derived from len(features) by
        # set_octant_display_features below.
        if data.shape[-1] < 6:
            features = features[:3]

        # Scale the data set to span an octant
        data2d = np.array(data.reshape((-1, data.shape[-1])))
        mins = np.min(data2d, axis=0)
        maxes = np.max(data2d, axis=0)
        denom = (maxes - mins).astype(float)
        # Avoid dividing by zero for features that are constant.
        denom = np.where(denom > 0, denom, 1.0)
        self.data = (data2d - mins) / denom
        self.data.shape = data.shape

        # Class 0 (unassigned) is drawn in white.
        self.palette = spy_colors.astype(float) / 255.
        self.palette[0] = np.array([1.0, 1.0, 1.0])
        rgb = self.palette[self.classes.ravel()].reshape(
            self.data.shape[:2] + (3,))
        rgb = (rgb * 255).astype('uint8')
        colors = np.ones(rgb.shape[:-1] + (4,), 'uint8')
        colors[:, :, :-1] = rgb
        self.colors = colors

        # The per-pixel ID colors depend on the number of pixels, so any
        # cached copy must be rebuilt for the new data.
        self._rgba_indices = None
        self._refresh_display_lists = True
        self.set_octant_display_features(features)

        # Determine the bit masks to use when using RGBA components for
        # identifying pixel IDs.
        components = [gl.GL_RED_BITS, gl.GL_GREEN_BITS,
                      gl.GL_BLUE_BITS, gl.GL_ALPHA_BITS]
        self._rgba_bits = [min(8, glGetIntegerv(i)) for i in components]
        self._low_bits = [min(8, 8 - self._rgba_bits[i]) for i in range(4)]
        self._rgba_masks = \
            [(2**self._rgba_bits[i] - 1) << (8 - self._rgba_bits[i])
             for i in range(4)]

        # Every pixel needs its own color, so the number of pixels is
        # limited by the number of available color bits.
        N = self.data.shape[0] * self.data.shape[1]
        if N > 2**sum(self._rgba_bits):
            raise Exception(
                'Insufficient color bits (%d) for N-D window display'
                % sum(self._rgba_bits))
        self.reset_view_geometry()

    def set_octant_display_features(self, features):
        '''Specifies features to be displayed in each 3-D coordinate octant.

        `features` can be any of the following:

        A length-3 list of integer feature IDs:

            In this case, the data points will be displayed in the positive
            x,y,z octant using features associated with the 3 integers.

        A length-6 list of integer feature IDs:

            In this case, each integer specifies a single feature index to be
            associated with the coordinate semi-axes x, y, z, -x, -y, and -z
            (in that order). Each octant will display data points using the
            features associated with the 3 semi-axes for that octant.

        A length-8 list of length-3 lists of integers:

            In this case, each length-3 list specifies the features to be
            displayed in a single octant (the same semi-axis can be
            associated with different features in different octants). Octants
            are ordered starting with the positive x,y,z octant and proceed
            counterclockwise around the z-axis, then proceed similarly around
            the negative half of the z-axis. An octant triplet can be
            specified as None instead of a list, in which case nothing will
            be rendered in that octant.

        Passing None is equivalent to passing `list(range(6))`.
        '''
        if features is None:
            features = list(range(6))
        if len(features) == 3:
            self.octant_features = [features] + [None] * 7
            new_quadrant_mode = 'single'
            self.target_pos = np.array([0.5, 0.5, 0.5])
        elif len(features) == 6:
            self.octant_features = create_mirrored_octants(features)
            new_quadrant_mode = 'mirrored'
            self.target_pos = np.array([0.0, 0.0, 0.0])
        else:
            self.octant_features = features
            new_quadrant_mode = 'independent'
            self.target_pos = np.array([0.0, 0.0, 0.0])
        if new_quadrant_mode != self.quadrant_mode:
            print('Setting quadrant display mode to %s.' % new_quadrant_mode)
        self.quadrant_mode = new_quadrant_mode
        self._refresh_display_lists = True

    def create_display_lists(self, npass=-1, **kwargs):
        '''Creates or updates the display lists for image data.

        ARGUMENTS:

            `npass` (int):

                When defaulted to -1, the normal image data display lists are
                created. When >=0, `npass` represents the rendering pass for
                identifying image pixels in the scene by their unique colors.

        KEYWORD ARGS:

            `indices` (list of ints):

                 An optional list of N-D image pixels to display.
        '''
        import OpenGL.GL as gl
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glPointSize(self.point_size)

        (R, C) = self.data.shape[:2]

        indices = kwargs.get('indices', None)
        if indices is None:
            indices = np.arange(R * C)
        else:
            indices = np.asarray(indices)
        if not self._show_unassigned:
            # Look up the class of each requested pixel so the mask always
            # has the same length as `indices`, even for a subset.
            indices = indices[self.classes.ravel()[indices] != 0]
        self._display_indices = indices

        if npass < 0:
            # Colors are associated with image pixel classes.
            gl.glColorPointerub(self.colors)
        else:
            # RGB pixel indices for selecting pixels with the mouse
            if self._rgba_indices is None:
                # Generate unique colors that correspond to each pixel's ID
                # so that the color can be used to identify the pixel.
                color_indices = np.arange(R * C)
                rgba = np.zeros((len(color_indices), 4), 'uint8')
                for i in range(4):
                    shift = sum(self._rgba_bits[0:i]) - self._low_bits[i]
                    if shift > 0:
                        rgba[:, i] = (
                            color_indices >> shift) & self._rgba_masks[i]
                    else:
                        rgba[:, i] = (color_indices << self._low_bits[i]) \
                            & self._rgba_masks[i]
                self._rgba_indices = rgba
            gl.glColorPointerub(self._rgba_indices)

        # Generate a display list for each octant of the 3-D window.
        for (i, octant) in enumerate(self.octant_features):
            if octant is not None:
                data = np.take(self.data, octant, axis=2).reshape((-1, 3))
                data *= octant_coeffs[i]
                gl.glVertexPointerf(data)
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glDrawElementsui(gl.GL_POINTS, indices)
                gl.glEndList()
            else:
                # Create an empty draw list
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glEndList()

        self.create_axes_list()
        self._refresh_display_lists = False

    def randomize_features(self):
        '''Randomizes data features displayed using current display mode.'''
        ids = list(range(self.data.shape[2]))
        if self.quadrant_mode == 'single':
            features = random_subset(ids, 3)
        elif self.quadrant_mode == 'mirrored':
            features = random_subset(ids, 6)
        else:
            features = [random_subset(ids, 3) for i in range(8)]
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)

    def set_features(self, features, mode='single'):
        '''Sets the displayed features after validating them against `mode`.

        ARGUMENTS:

            `features` (list):

                3 feature indices for "single" mode, 6 for "mirrored" mode,
                or 8 entries for "independent" mode.

            `mode` (str):

                One of 'single', 'mirrored' or 'independent'.

        Raises an Exception if `mode` is not recognized or if the length of
        `features` does not match it.
        '''
        if mode == 'single':
            if len(features) != 3:
                raise Exception(
                    'Expected 3 feature indices for "single" mode.')
        elif mode == 'mirrored':
            if len(features) != 6:
                raise Exception(
                    'Expected 6 feature indices for "mirrored" mode.')
        elif mode == 'independent':
            if len(features) != 8:
                raise Exception('Expected 8 3-tuples of feature indices for'
                                '"independent" mode.')
        else:
            raise Exception('Unrecognized feature mode: %s.' % str(mode))
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)
        self.Refresh()

    def draw_box(self, x0, y0, x1, y1):
        '''Draws a selection box in the 3-D window.

        Coordinates are with respect to the lower left corner of the window.
        '''
        import OpenGL.GL as gl
        # Switch to an orthographic projection in window coordinates.
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glOrtho(0.0, self.size[0],
                   0.0, self.size[1],
                   -0.01, 10.0)

        gl.glLineStipple(1, 0xF00F)
        gl.glEnable(gl.GL_LINE_STIPPLE)
        gl.glLineWidth(1.0)
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glBegin(gl.GL_LINE_LOOP)
        gl.glVertex3f(x0, y0, 0.0)
        gl.glVertex3f(x1, y0, 0.0)
        gl.glVertex3f(x1, y1, 0.0)
        gl.glVertex3f(x0, y1, 0.0)
        gl.glEnd()
        gl.glDisable(gl.GL_LINE_STIPPLE)
        gl.glFlush()

        # Restore the perspective projection used for the scene.
        self.resize(*self.size)

    def on_paint(self, event):
        '''Renders the entire scene.'''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        self.canvas.SetCurrent(self.canvas.context)
        if not self.gl_initialized:
            self.initgl()
            self.gl_initialized = True
            self.print_help()
            self.resize(*self.size)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # Run (and discard) any commands queued by add_display_command.
        while len(self._display_commands) > 0:
            self._display_commands.pop(0)()

        if self._refresh_display_lists:
            self.create_display_lists()

        gl.glPushMatrix()
        # camera_pos_rtp is relative to target position. To get the absolute
        # camera position, we need to add the target position.
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(
            *(list(camera_pos_xyz) + list(self.target_pos) + self.up))

        if self.show_axes_tf:
            gl.glCallList(self.gllist_id)

        self.draw_data_set()

        gl.glPopMatrix()
        gl.glFlush()

        if self._selection_box is not None:
            self.draw_box(*self._selection_box)

        self.SwapBuffers()
        event.Skip()

    def post_reassign_selection(self, new_class):
        '''Reassigns pixels in selection box during the next rendering loop.

        ARGUMENT:

            `new_class` (int):

                The class to which the pixels in the box will be assigned.

        Returns 0 in all cases.
        '''
        if self._selection_box is None:
            msg = 'Bounding box is not selected. Hold SHIFT and click & ' + \
                  'drag with the left\nmouse button to select a region.'
            print(msg)
            return 0
        self.add_display_command(lambda: self.reassign_selection(new_class))
        self.canvas.Refresh()
        return 0

    def reassign_selection(self, new_class):
        '''Reassigns pixels in the selection box to the specified class.

        This method should only be called from the `display` method. Pixels
        are reassigned by identifying each pixel in the 3D display by their
        unique color, then reassigning them. Since pixels can block others in
        the z-buffer, this method iteratively reassigns pixels by removing any
        reassigned pixels from the display list, then reassigning again,
        repeating until there are no more pixels in the selection box.

        Returns the total number of reassigned pixels. If any pixel changed,
        a 'spy_classes_modified' event is sent through `self.callbacks`.
        '''
        nreassigned_tot = 0
        print('Reassigning points', end=' ')
        while True:
            # Only render points that do not already belong to `new_class`.
            indices = np.array(self._display_indices)
            classes = np.array(self.classes.ravel()[indices])
            indices = indices[np.where(classes != new_class)]
            ids = self.get_points_in_selection_box(indices=indices)
            cr = self.classes.ravel()
            nreassigned = np.sum(cr[ids] != new_class)
            nreassigned_tot += nreassigned
            cr[ids] = new_class
            new_color = np.zeros(4, 'uint8')
            new_color[:3] = (np.array(self.palette[new_class])
                             * 255).astype('uint8')
            self.colors.reshape((-1, 4))[ids] = new_color
            self.create_display_lists()
            if len(ids) == 0:
                break
            print('.', end=' ')
        print('\n%d points were reasssigned to class %d.'
              % (nreassigned_tot, new_class))
        self._selection_box = None
        if nreassigned_tot > 0 and new_class == self.max_menu_class:
            self.max_menu_class += 1

        if nreassigned_tot > 0:
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = nreassigned_tot
            self.callbacks.process('spy_classes_modified', event)

        return nreassigned_tot

    def get_points_in_selection_box(self, **kwargs):
        '''Returns pixel IDs of all points in the current selection box.

        KEYWORD ARGS:

            `indices` (ndarray of ints):

                An alternate set of N-D image pixels to display.

            `point_size`:

                Point size to use for the background rendering (default 1).

        Pixels are identified by performing a background rendering loop wherein
        each pixel is rendered with a unique color. Then, glReadPixels is used
        to read colors of pixels in the current selection box. IDs that decode
        to 0 are treated as background and are not returned.
        '''
        import OpenGL.GL as gl
        indices = kwargs.get('indices', None)
        point_size_temp = self.point_size
        self.point_size = kwargs.get('point_size', 1)
        try:
            xsize = self._selection_box[2] - self._selection_box[0] + 1
            ysize = self._selection_box[3] - self._selection_box[1] + 1
            ids = np.zeros(xsize * ysize, int)

            self.create_display_lists(0, indices=indices)
            self.render_rgb_indexed_colors()
            gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
            pixels = gl.glReadPixelsub(self._selection_box[0],
                                       self._selection_box[1],
                                       xsize, ysize, gl.GL_RGBA)
            pixels = np.frombuffer(pixels, dtype=np.uint8).reshape(
                (ysize, xsize, 4))
            # Reassemble each pixel ID from its four color components; this
            # is the inverse of the encoding in create_display_lists.
            for i in range(4):
                component = pixels[:, :, i].reshape((xsize * ysize,)) \
                    & self._rgba_masks[i]
                shift = (sum(self._rgba_bits[0:i]) - self._low_bits[i])
                if shift > 0:
                    ids += component.astype(int) << shift
                else:
                    ids += component.astype(int) >> (-shift)

            points = ids[ids > 0]
        finally:
            # Always restore the user's point size, even if rendering failed.
            self.point_size = point_size_temp

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # The display lists now hold ID colors; rebuild them on next paint.
        self._refresh_display_lists = True

        return points

    def get_pixel_info(self, x, y, **kwargs):
        '''Prints row/col of the pixel at the given raster position.

        ARGUMENTS:

            `x`, `y`: (int):

                The pixel's coordinates relative to the lower left corner.
        '''
        self._selection_box = (x, y, x, y)
        ids = self.get_points_in_selection_box(point_size=self.point_size)
        for pixel_id in ids:
            if pixel_id > 0:
                rc = self.index_to_image_row_col(pixel_id)
                print('Pixel %d %s has class %s.'
                      % (pixel_id, rc, self.classes[rc]))
        return

    def render_rgb_indexed_colors(self, **kwargs):
        '''Draws scene in the background buffer to extract mouse click info'''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # camera_pos_rtp is relative to the target position. To get the
        # absolute camera position, we need to add the target position.
        gl.glPushMatrix()
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(
            *(list(camera_pos_xyz) + list(self.target_pos) + self.up))

        self.draw_data_set()

        gl.glPopMatrix()
        gl.glFlush()

    def index_to_image_row_col(self, index):
        '''Converts the unraveled pixel ID to row/col of the N-D image.'''
        ncols = self.data.shape[1]
        return (index // ncols, index % ncols)

    def draw_data_set(self):
        '''Draws the N-D data set in the scene.'''
        import OpenGL.GL as gl

        # Lists 1..8 hold the octants; list 0 holds the axes.
        for i in range(1, 9):
            gl.glCallList(self.gllist_id + i)

    def create_axes_list(self):
        '''Creates display lists to render unit length x,y,z axes.'''
        import OpenGL.GL as gl
        gl.glNewList(self.gllist_id, gl.GL_COMPILE)
        gl.glBegin(gl.GL_LINES)
        # Positive semi-axes are drawn in red, green and blue.
        gl.glColor3f(1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(1.0, 0.0, 0.0)
        gl.glColor3f(0.0, 1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 1.0, 0.0)
        gl.glColor3f(0.0, 0.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 1.0)
        # Negative semi-axes are drawn in white.
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(-1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, -1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, -1.0)
        gl.glEnd()

        def label_axis(x, y, z, label):
            gl.glRasterPos3f(x, y, z)
            glut.glutBitmapString(glut.GLUT_BITMAP_HELVETICA_18,
                                  str(label))

        def label_axis_for_feature(x, y, z, feature_ind):
            feature = self.octant_features[feature_ind[0]][feature_ind[1]]
            label_axis(x, y, z, self.labels[feature])

        if self._have_glut:
            # Axis labels are optional: any failure while drawing them is
            # ignored so the axes themselves are still rendered.
            try:
                import OpenGL.GLUT as glut
                if bool(glut.glutBitmapString):
                    if self.quadrant_mode == 'independent':
                        label_axis(1.05, 0.0, 0.0, 'x')
                        label_axis(0.0, 1.05, 0.0, 'y')
                        label_axis(0.0, 0.0, 1.05, 'z')
                    elif self.quadrant_mode == 'mirrored':
                        label_axis_for_feature(1.05, 0.0, 0.0, (0, 0))
                        label_axis_for_feature(0.0, 1.05, 0.0, (0, 1))
                        label_axis_for_feature(0.0, 0.0, 1.05, (0, 2))
                        label_axis_for_feature(-1.05, 0.0, 0.0, (6, 0))
                        label_axis_for_feature(0.0, -1.05, 0.0, (6, 1))
                        label_axis_for_feature(0.0, 0.0, -1.05, (6, 2))
                    else:
                        label_axis_for_feature(1.05, 0.0, 0.0, (0, 0))
                        label_axis_for_feature(0.0, 1.05, 0.0, (0, 1))
                        label_axis_for_feature(0.0, 0.0, 1.05, (0, 2))
            except Exception:
                pass
        gl.glEndList()

    def GetGLExtents(self):
        """Get the extents of the OpenGL canvas (currently returns None)."""
        return

    def SwapBuffers(self):
        """Swap the OpenGL buffers."""
        self.canvas.SwapBuffers()

    def on_erase_background(self, event):
        """Process the erase background event."""
        pass  # Do nothing, to avoid flashing on MSWin

    def initgl(self):
        '''App-specific initialization for after GLUT has been initialized.'''
        import OpenGL.GL as gl
        # One list for the axes plus one per octant.
        self.gllist_id = gl.glGenLists(9)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glDisable(gl.GL_LIGHTING)
        gl.glDisable(gl.GL_TEXTURE_2D)
        gl.glDisable(gl.GL_FOG)
        gl.glDisable(gl.GL_COLOR_MATERIAL)
        gl.glEnable(gl.GL_DEPTH_TEST)
        gl.glShadeModel(gl.GL_FLAT)
        self.set_data(self.data, classes=self.classes, features=self.features)

        # GLUT is only needed for axis labels, so its absence (or a failure
        # to initialize it) is tolerated.
        try:
            import OpenGL.GLUT as glut
            glut.glutInit()
            self._have_glut = True
        except Exception:
            pass

    def on_resize(self, event):
        '''Process the resize event.'''
        # For wx versions 2.9.x, GLCanvas.GetContext() always returns None,
        # whereas 2.8.x will return the context so test for both versions.
        if wx.VERSION >= (2, 9) or self.canvas.GetContext():
            self.canvas.SetCurrent(self.canvas.context)
            # Make sure the frame is shown before calling SetCurrent.
            self.Show()
            size = event.GetSize()
            self.resize(size.width, size.height)
            self.canvas.Refresh(False)
        event.Skip()

    def resize(self, width, height):
        """Reshape the OpenGL viewport based on dimensions of the window."""
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        self.size = (width, height)
        gl.glViewport(0, 0, width, height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        glu.gluPerspective(self.fovy, float(width) / height,
                           self.znear, self.zfar)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()

    def on_char(self, event):
        '''Callback function for when a keyboard button is pressed.'''
        key = chr(event.GetKeyCode())

        # See `print_help` method for explanation of keybinds.
        if key == 'a':
            self.show_axes_tf = not self.show_axes_tf
        elif key == 'c':
            self.view_class_image()
        elif key == 'd':
            if self.data.shape[2] < 6:
                print('Only single-quadrant mode is supported for %d features.'
                      % self.data.shape[2])
                return
            if self.quadrant_mode == 'single':
                self.quadrant_mode = 'mirrored'
            elif self.quadrant_mode == 'mirrored':
                self.quadrant_mode = 'independent'
            else:
                self.quadrant_mode = 'single'
            print('Setting quadrant display mode to %s.' % self.quadrant_mode)
            self.randomize_features()
        elif key == 'f':
            self.randomize_features()
        elif key == 'h':
            self.print_help()
        elif key == 'm':
            self.mouse_panning = not self.mouse_panning
        elif key == 'p':
            self.point_size += 1
            self._refresh_display_lists = True
        elif key == 'P':
            self.point_size = max(self.point_size - 1, 1.0)
            self._refresh_display_lists = True
        elif key == 'q':
            self.on_event_close()
            self.Close(True)
        elif key == 'r':
            self.reset_view_geometry()
        elif key == 'u':
            self._show_unassigned = not self._show_unassigned
            print('SHOW UNASSIGNED =', self._show_unassigned)
            self._refresh_display_lists = True

        self.canvas.Refresh()

    def update_window_title(self):
        '''Sets the title of the window.'''
        s = 'SPy N-D Data Set'
        # The window is a wx.Frame (there is no GLUT window), so the title
        # is set through the frame itself.
        self.SetTitle(s)

    def get_proxy(self):
        '''Returns a proxy object to access data from the window.'''
        return NDWindowProxy(self)

    def view_class_image(self, *args, **kwargs):
        '''Opens a dynamic raster image of class values.

        The class IDs displayed are those currently associated with the ND
        window. `args` and `kwargs` are additional arguments passed on to the
        `ImageView` constructor. Return value is the ImageView object.
        '''
        view = ImageView(*args, classes=self.classes, **kwargs)
        # Share this window's registry so the image view is notified when
        # classes are modified here.
        view.callbacks_common = self.callbacks
        view.show()
        return view

    def print_help(self):
        '''Prints a list of accepted keyboard/mouse inputs.'''
        print('''Mouse functions:
---------------
Left-click & drag       -->     Rotate viewing geometry (or pan)
CTRL+Left-click & drag  -->     Zoom viewing geometry
CTRL+SHIFT+Left-click   -->     Print image row/col and class of selected pixel
SHIFT+Left-click & drag -->     Define selection box in the window
Right-click             -->     Open GLUT menu for pixel reassignment

Keyboard functions:
-------------------
a       -->     Toggle axis display
c       -->     View dynamic raster image of class values
d       -->     Cycle display mode between single-quadrant, mirrored octants,
                and independent octants (display will not change until features
                are randomzed again)
f       -->     Randomize features displayed
h       -->     Print this help message
m       -->     Toggle mouse function between rotate/zoom and pan modes
p/P     -->     Increase/Decrease the size of displayed points
q       -->     Exit the application
r       -->     Reset viewing geometry
u       -->     Toggle display of unassigned points (points with class == 0)
''')
def __init__(self, data, parent, id, *args, **kwargs):
    '''Creates the frame, its OpenGL canvas, and the default view state.

    ARGUMENTS:

        `data` (numpy.ndarray):

            Array of data points; the last axis indexes the features.

        `parent`, `id`:

            Passed through to the `wx.Frame` constructor.

    KEYWORD ARGUMENTS:

        `size`:

            Window size, unpacked into `wx.Size`. Defaults to
            `DEFAULT_WIN_SIZE`.

        `title` (str):

            Window title. Defaults to 'ND Window'.

        `classes` (numpy.ndarray):

            Integer class labels, one per data point. Defaults to an array
            of zeros with shape `data.shape[:-1]`.

        `features` (list):

            Feature indices to display. Defaults to `list(range(6))`.

        `labels` (list):

            Axis labels, indexed by feature. Defaults to the feature indices
            themselves.
    '''
    # Imported locally so the constructor does not rely on a module-level
    # `settings` name being present.
    from spectral import settings
    from matplotlib.cbook import CallbackRegistry

    self.kwargs = kwargs
    self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
    self.title = kwargs.get('title', 'ND Window')

    # Forcing a specific style on the window.
    # Should this include styles passed?
    style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE

    super(NDWindow, self).__init__(parent, id, self.title,
                                   wx.DefaultPosition,
                                   wx.Size(*self.size),
                                   style,
                                   self.title)

    self.gl_initialized = False
    attribs = (glcanvas.WX_GL_RGBA,
               glcanvas.WX_GL_DOUBLEBUFFER,
               glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
    self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
    self.canvas.context = wx.glcanvas.GLContext(self.canvas)

    self._have_glut = False
    self.clear_color = (0, 0, 0, 0)
    self.show_axes_tf = True
    self.point_size = 1.0
    self._show_unassigned = True
    self._refresh_display_lists = False
    self._click_tolerance = 1
    self._display_commands = []
    self._selection_box = None
    self._rgba_indices = None
    self.mouse_panning = False
    self.win_pos = (100, 100)
    self.fovy = 60.
    self.znear = 0.1
    self.zfar = 10.0
    self.target_pos = [0.0, 0.0, 0.0]
    self.camera_pos_rtp = [7.0, 45.0, 30.0]
    self.up = [0.0, 0.0, 1.0]

    self.quadrant_mode = None
    self.mouse_handler = MouseHandler(self)

    # Set the event handlers.
    self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
    self.Bind(wx.EVT_SIZE, self.on_resize)
    self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
    self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
    self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
    self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
    self.canvas.Bind(wx.EVT_CHAR, self.on_char)
    self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
    self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

    self.data = data
    # `np.int` was removed from NumPy (1.24); the builtin `int` is the
    # documented replacement. The default is only built when needed.
    classes = kwargs.get('classes')
    if classes is None:
        classes = np.zeros(data.shape[:-1], int)
    self.classes = classes
    self.features = kwargs.get('features', list(range(6)))
    self.labels = kwargs.get('labels', list(range(data.shape[-1])))
    self.max_menu_class = int(np.max(self.classes.ravel() + 1))

    self.callbacks = CallbackRegistry()
def __init__(self, data, parent, id, *args, **kwargs):
    '''Creates the frame, its OpenGL canvas, and the default view state.

    ARGUMENTS:

        `data` (numpy.ndarray):

            Array of data points; the last axis indexes the features.

        `parent`, `id`:

            Passed through to the `wx.Frame` constructor.

    KEYWORD ARGUMENTS:

        `size`:

            Window size, unpacked into `wx.Size`. Defaults to
            `DEFAULT_WIN_SIZE`.

        `title` (str):

            Window title. Defaults to 'ND Window'.

        `classes` (numpy.ndarray):

            Integer class labels, one per data point. Defaults to an array
            of zeros with shape `data.shape[:-1]`.

        `features` (list):

            Feature indices to display. Defaults to `list(range(6))`.
    '''
    from spectral import settings
    from matplotlib.cbook import CallbackRegistry

    self.kwargs = kwargs
    self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
    self.title = kwargs.get('title', 'ND Window')

    # Forcing a specific style on the window.
    # Should this include styles passed?
    style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE

    super(NDWindow, self).__init__(parent, id, self.title,
                                   wx.DefaultPosition,
                                   wx.Size(*self.size),
                                   style,
                                   self.title)

    self.gl_initialized = False
    attribs = (glcanvas.WX_GL_RGBA,
               glcanvas.WX_GL_DOUBLEBUFFER,
               glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
    self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
    self.canvas.context = wx.glcanvas.GLContext(self.canvas)

    self._have_glut = False
    self.clear_color = (0, 0, 0, 0)
    self.show_axes_tf = True
    self.point_size = 1.0
    self._show_unassigned = True
    self._refresh_display_lists = False
    self._click_tolerance = 1
    self._display_commands = []
    self._selection_box = None
    self._rgba_indices = None
    self.mouse_panning = False
    self.win_pos = (100, 100)
    self.fovy = 60.
    self.znear = 0.1
    self.zfar = 10.0
    self.target_pos = [0.0, 0.0, 0.0]
    self.camera_pos_rtp = [7.0, 45.0, 30.0]
    self.up = [0.0, 0.0, 1.0]

    self.quadrant_mode = None
    self.mouse_handler = MouseHandler(self)

    # Set the event handlers.
    self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
    self.Bind(wx.EVT_SIZE, self.on_resize)
    self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
    self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
    self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
    self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
    self.canvas.Bind(wx.EVT_CHAR, self.on_char)
    self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
    self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

    self.data = data
    # `np.int` was removed from NumPy (1.24); the builtin `int` is the
    # documented replacement. The default is only built when needed.
    classes = kwargs.get('classes')
    if classes is None:
        classes = np.zeros(data.shape[:-1], int)
    self.classes = classes
    self.features = kwargs.get('features', list(range(6)))
    self.max_menu_class = int(np.max(self.classes.ravel() + 1))

    self.callbacks = CallbackRegistry()
class NDWindow(wx.Frame):
    '''A window class for displaying N-dimensional data points.'''

    def __init__(self, data, parent, id, *args, **kwargs):
        '''Creates the frame, its OpenGL canvas, and the default view state.

        ARGUMENTS:

            data (numpy.ndarray):

                Array of data points; every axis but the last indexes a
                point (the default `classes` array has shape
                `data.shape[:-1]`).

            parent, id:

                Passed unchanged to the `wx.Frame` constructor.

        KEYWORD ARGUMENTS:

            size (tuple):

                Window size; defaults to `DEFAULT_WIN_SIZE`.

            title (str):

                Window title; defaults to 'ND Window'.

            classes (numpy.ndarray):

                Integer class labels; defaults to an all-zero array.

            features (list):

                Feature indices to display; defaults to `list(range(6))`.
        '''
        from spectral import settings

        self.kwargs = kwargs
        self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
        self.title = kwargs.get('title', 'ND Window')

        #
        # Forcing a specific style on the window.
        #   Should this include styles passed?
        style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE

        super(NDWindow, self).__init__(parent, id, self.title,
                                       wx.DefaultPosition,
                                       wx.Size(*self.size),
                                       style,
                                       self.title)

        self.gl_initialized = False
        attribs = (glcanvas.WX_GL_RGBA,
                   glcanvas.WX_GL_DOUBLEBUFFER,
                   glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
        self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
        self.canvas.context = wx.glcanvas.GLContext(self.canvas)

        # Rendering / interaction state.
        self._have_glut = False
        self.clear_color = (0, 0, 0, 0)
        self.show_axes_tf = True
        self.point_size = 1.0
        self._show_unassigned = True
        self._refresh_display_lists = False
        self._click_tolerance = 1
        self._display_commands = []
        self._selection_box = None
        self._rgba_indices = None
        self.mouse_panning = False
        self.win_pos = (100, 100)

        # Viewing geometry.
        self.fovy = 60.
        self.znear = 0.1
        self.zfar = 10.0
        self.target_pos = [0.0, 0.0, 0.0]
        self.camera_pos_rtp = [7.0, 45.0, 30.0]
        self.up = [0.0, 0.0, 1.0]

        self.quadrant_mode = None
        self.mouse_handler = MouseHandler(self)

        # Set the event handlers.
        self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        self.Bind(wx.EVT_SIZE, self.on_resize)
        self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
        self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
        self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
        self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
        self.canvas.Bind(wx.EVT_CHAR, self.on_char)
        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
        self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

        self.data = data
        # `np.int` was removed from NumPy (1.24); the builtin `int` is the
        # drop-in replacement for the old alias.
        self.classes = kwargs.get('classes',
                                  np.zeros(data.shape[:-1], int))
        self.features = kwargs.get('features', list(range(6)))

        # First class ID that is not yet in use.
        self.max_menu_class = int(np.max(self.classes.ravel() + 1))

        from matplotlib.cbook import CallbackRegistry
        self.callbacks = CallbackRegistry()

    def on_event_close(self, event=None):
        '''Hook called when the window is closed (does nothing here).'''
        pass

    def right_click(self, event):
        '''Opens the popup `MouseMenu` at the position of the click.'''
        self.canvas.SetCurrent(self.canvas.context)
        self.canvas.PopupMenu(MouseMenu(self), event.GetPosition())

    def add_display_command(self, cmd):
        '''Adds a command to be called next time `display` is run.'''
        self._display_commands.append(cmd)

    def reset_view_geometry(self):
        '''Sets viewing geometry to the default view.'''
        # All grid points will be adjusted to the range [0,1] so this
        # is a reasonable center coordinate for the scene
        self.target_pos = np.array([0.0, 0.0, 0.0])

        # Specify the camera location in spherical polar coordinates relative
        # to target_pos.
        self.camera_pos_rtp = [2.5, 45.0, 30.0]

    def set_data(self, data, **kwargs):
        '''Associates N-D point data with the window.

        ARGUMENTS:

            data (numpy.ndarray):

                An RxCxB array of data points to display.

        KEYWORD ARGUMENTS:

            classes (numpy.ndarray):

                An RxC array of integer class labels (zeros means
                unassigned). If omitted, the labels already held by the
                window are kept.

            features (list):

                Indices of features to display in the octant (see
                NDWindow.set_octant_display_features for description).

        Raises an `Exception` if the OpenGL color buffer does not have
        enough bits to give every data point a unique color.
        '''
        import OpenGL.GL as gl
        try:
            from OpenGL.GL import glGetIntegerv
        except ImportError:
            from OpenGL.GL.glget import glGetIntegerv

        classes = kwargs.get('classes', None)
        if classes is not None:
            # Honor the documented keyword (it was previously ignored).
            self.classes = classes
        features = kwargs.get('features', list(range(6)))
        # Test the data being set, not whatever was previously stored.
        if data.shape[-1] < 6:
            # Too few bands for mirrored/independent octants. The length-3
            # feature list makes set_octant_display_features (called below)
            # select 'single' quadrant mode.
            features = features[:3]

        # Scale the data set to span an octant

        data2d = np.array(data.reshape((-1, data.shape[-1])))
        mins = np.min(data2d, axis=0)
        maxes = np.max(data2d, axis=0)
        denom = (maxes - mins).astype(float)
        # Constant-valued features would otherwise divide by zero.
        denom = np.where(denom > 0, denom, 1.0)
        self.data = (data2d - mins) / denom
        self.data.shape = data.shape

        self.palette = spy_colors.astype(float) / 255.
        self.palette[0] = np.array([1.0, 1.0, 1.0])
        self.colors = self.palette[self.classes.ravel()].reshape(
            self.data.shape[:2] + (3,))
        self.colors = (self.colors * 255).astype('uint8')
        # Append a fourth (alpha) component to each color.
        colors = np.ones((self.colors.shape[:-1]) + (4,), 'uint8')
        colors[:, :, :-1] = self.colors
        self.colors = colors
        self._refresh_display_lists = True
        self.set_octant_display_features(features)

        # Determine the bit masks to use when using RGBA components for
        # identifying pixel IDs.
        components = [gl.GL_RED_BITS, gl.GL_GREEN_BITS,
                      gl.GL_BLUE_BITS, gl.GL_ALPHA_BITS]
        self._rgba_bits = [min(8, glGetIntegerv(i)) for i in components]
        self._low_bits = [min(8, 8 - self._rgba_bits[i]) for i in range(4)]
        self._rgba_masks = \
            [(2**self._rgba_bits[i] - 1) << (8 - self._rgba_bits[i])
             for i in range(4)]

        # Determine how many times the scene will need to be rendered in the
        # background to extract the pixel's row/col index.

        N = self.data.shape[0] * self.data.shape[1]
        if N > 2**sum(self._rgba_bits):
            raise Exception(
                'Insufficient color bits (%d) for N-D window display'
                % sum(self._rgba_bits))
        self.reset_view_geometry()

    def set_octant_display_features(self, features):
        '''Specifies features to be displayed in each 3-D coordinate octant.

        `features` can be any of the following:

        A length-3 list of integer feature IDs:

            In this case, the data points will be displayed in the positive
            x,y,z octant using features associated with the 3 integers.

        A length-6 list of integer feature IDs:

            In this case, each integer specifies a single feature index to be
            associated with the coordinate semi-axes x, y, z, -x, -y, and -z
            (in that order).  Each octant will display data points using the
            features associated with the 3 semi-axes for that octant.

        A length-8 list of length-3 lists of integers:

            In this case, each length-3 list specifies the features to be
            displayed in a single octant (the same semi-axis can be
            associated with different features in different octants).
            Octants are ordered starting with the positive x,y,z octant and
            proceed counterclockwise around the z-axis, then proceed
            similarly around the negative half of the z-axis.  An octant
            triplet can be specified as None instead of a list, in which
            case nothing will be rendered in that octant.
        '''
        if features is None:
            features = list(range(6))
        if len(features) == 3:
            self.octant_features = [features] + [None] * 7
            new_quadrant_mode = 'single'
            self.target_pos = np.array([0.5, 0.5, 0.5])
        elif len(features) == 6:
            self.octant_features = create_mirrored_octants(features)
            new_quadrant_mode = 'mirrored'
            self.target_pos = np.array([0.0, 0.0, 0.0])
        else:
            self.octant_features = features
            new_quadrant_mode = 'independent'
            self.target_pos = np.array([0.0, 0.0, 0.0])
        if new_quadrant_mode != self.quadrant_mode:
            print('Setting quadrant display mode to %s.' % new_quadrant_mode)
            self.quadrant_mode = new_quadrant_mode
        self._refresh_display_lists = True

    def create_display_lists(self, npass=-1, **kwargs):
        '''Creates or updates the display lists for image data.

        ARGUMENTS:

            `npass` (int):

                When defaulted to -1, the normal image data display lists are
                created.  When >=0, `npass` represents the rendering pass for
                identifying image pixels in the scene by their unique colors.

        KEYWORD ARGS:

            `indices` (list of ints):

                 An optional list of N-D image pixels to display.
        '''
        import OpenGL.GL as gl
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glPointSize(self.point_size)
        gl.glColorPointerub(self.colors)

        (R, C, B) = self.data.shape

        indices = kwargs.get('indices', None)
        if indices is None:
            indices = np.arange(R * C)
        else:
            indices = np.asarray(indices)
        if not self._show_unassigned:
            # Look the classes up *at* `indices`: a full-length mask only
            # lines up with `indices` when every pixel is listed.
            indices = indices[self.classes.ravel()[indices] != 0]
        self._display_indices = indices

        # RGB pixel indices for selecting pixels with the mouse
        if npass < 0:
            # Colors are associated with image pixel classes.
            gl.glColorPointerub(self.colors)
        else:
            if self._rgba_indices is None:
                # Generate unique colors that correspond to each pixel's ID
                # so that the color can be used to identify the pixel.
                color_indices = np.arange(R * C)
                rgba = np.zeros((len(color_indices), 4), 'uint8')
                for i in range(4):
                    shift = sum(self._rgba_bits[0:i]) - self._low_bits[i]
                    if shift > 0:
                        rgba[:, i] = (
                            color_indices >> shift) & self._rgba_masks[i]
                    else:
                        rgba[:, i] = (color_indices << self._low_bits[i]) \
                            & self._rgba_masks[i]
                self._rgba_indices = rgba
            gl.glColorPointerub(self._rgba_indices)

        # Generate a display list for each octant of the 3-D window.

        for (i, octant) in enumerate(self.octant_features):
            if octant is not None:
                data = np.take(self.data, octant, axis=2).reshape((-1, 3))
                data *= octant_coeffs[i]
                gl.glVertexPointerf(data)
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glDrawElementsui(gl.GL_POINTS, indices)
                gl.glEndList()
            else:
                # Create an empty draw list
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glEndList()

        self.create_axes_list()
        self._refresh_display_lists = False

    def randomize_features(self):
        '''Randomizes data features displayed using current display mode.'''
        from pprint import pprint

        ids = list(range(self.data.shape[2]))
        if self.quadrant_mode == 'single':
            features = random_subset(ids, 3)
        elif self.quadrant_mode == 'mirrored':
            features = random_subset(ids, 6)
        else:
            features = [random_subset(ids, 3) for i in range(8)]
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)

    def set_features(self, features, mode='single'):
        '''Validates `features` against `mode`, then displays them.

        `mode` must be 'single' (3 feature indices), 'mirrored' (6 feature
        indices), or 'independent' (8 triplets of feature indices); an
        `Exception` is raised for any other mode or a wrong length.
        '''
        from pprint import pprint

        if mode == 'single':
            if len(features) != 3:
                raise Exception(
                    'Expected 3 feature indices for "single" mode.')
        elif mode == 'mirrored':
            if len(features) != 6:
                raise Exception(
                    'Expected 6 feature indices for "mirrored" mode.')
        elif mode == 'independent':
            if len(features) != 8:
                raise Exception('Expected 8 3-tuples of feature indices for'
                                '"independent" mode.')
        else:
            raise Exception('Unrecognized feature mode: %s.' % str(mode))
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)
        self.Refresh()

    def draw_box(self, x0, y0, x1, y1):
        '''Draws a selection box in the 3-D window.

        Coordinates are with respect to the lower left corner of the window.
        '''
        import OpenGL.GL as gl
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glOrtho(0.0, self.size[0],
                   0.0, self.size[1],
                   -0.01, 10.0)

        gl.glLineStipple(1, 0xF00F)
        gl.glEnable(gl.GL_LINE_STIPPLE)
        gl.glLineWidth(1.0)
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glBegin(gl.GL_LINE_LOOP)
        gl.glVertex3f(x0, y0, 0.0)
        gl.glVertex3f(x1, y0, 0.0)
        gl.glVertex3f(x1, y1, 0.0)
        gl.glVertex3f(x0, y1, 0.0)
        gl.glEnd()
        gl.glDisable(gl.GL_LINE_STIPPLE)
        gl.glFlush()

        # Restore the perspective projection replaced by glOrtho above.
        self.resize(*self.size)

    def on_paint(self, event):
        '''Renders the entire scene.'''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        self.canvas.SetCurrent(self.canvas.context)
        if not self.gl_initialized:
            self.initgl()
            self.gl_initialized = True
            self.print_help()
            self.resize(*self.size)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # Run (and consume) the commands queued by add_display_command.
        while len(self._display_commands) > 0:
            self._display_commands.pop(0)()

        if self._refresh_display_lists:
            self.create_display_lists()

        gl.glPushMatrix()
        # camera_pos_rtp is relative to target position. To get the absolute
        # camera position, we need to add the target position.
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(
            *(list(camera_pos_xyz) + list(self.target_pos) + self.up))

        if self.show_axes_tf:
            gl.glCallList(self.gllist_id)

        self.draw_data_set()

        gl.glPopMatrix()
        gl.glFlush()

        if self._selection_box is not None:
            self.draw_box(*self._selection_box)

        self.SwapBuffers()
        event.Skip()

    def post_reassign_selection(self, new_class):
        '''Reassigns pixels in selection box during the next rendering loop.

        ARGUMENT:

            `new_class` (int):

                The class to which the pixels in the box will be assigned.
        '''
        if self._selection_box is None:
            msg = 'Bounding box is not selected. Hold SHIFT and click & ' + \
                  'drag with the left\nmouse button to select a region.'
            print(msg)
            return 0
        self.add_display_command(lambda: self.reassign_selection(new_class))
        self.canvas.Refresh()
        return 0

    def reassign_selection(self, new_class):
        '''Reassigns pixels in the selection box to the specified class.

        This method should only be called from the `display` method. Pixels
        are reassigned by identifying each pixel in the 3D display by their
        unique color, then reassigning them. Since pixels can block others
        in the z-buffer, this method iteratively reassigns pixels by removing
        any reassigned pixels from the display list, then reassigning again,
        repeating until there are no more pixels in the selection box.

        Returns the total number of pixels that changed class. If any pixel
        changed, a 'spy_classes_modified' event is sent through
        `self.callbacks`.
        '''
        nreassigned_tot = 0
        print('Reassigning points', end=' ')
        while True:
            # Only points not already in `new_class` need to be rendered.
            indices = np.array(self._display_indices)
            classes = np.array(self.classes.ravel()[indices])
            indices = indices[np.where(classes != new_class)]
            ids = self.get_points_in_selection_box(indices=indices)
            cr = self.classes.ravel()
            nreassigned = np.sum(cr[ids] != new_class)
            nreassigned_tot += nreassigned
            cr[ids] = new_class
            new_color = np.zeros(4, 'uint8')
            new_color[:3] = (np.array(self.palette[new_class])
                             * 255).astype('uint8')
            self.colors.reshape((-1, 4))[ids] = new_color
            self.create_display_lists()
            if len(ids) == 0:
                break
            print('.', end=' ')
        print('\n%d points were reasssigned to class %d.'
              % (nreassigned_tot, new_class))
        self._selection_box = None
        if nreassigned_tot > 0 and new_class == self.max_menu_class:
            self.max_menu_class += 1

        if nreassigned_tot > 0:
            from .spypylab import SpyMplEvent
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = nreassigned_tot
            self.callbacks.process('spy_classes_modified', event)

        return nreassigned_tot

    def get_points_in_selection_box(self, **kwargs):
        '''Returns pixel IDs of all points in the current selection box.

        KEYWORD ARGS:

            `indices` (ndarray of ints):

                An alternate set of N-D image pixels to display.

            `point_size` (number):

                Point size used for the background rendering (default 1).

        Pixels are identified by performing a background rendering loop
        wherein each pixel is rendered with a unique color. Then,
        glReadPixels is used to read colors of pixels in the current
        selection box.
        '''
        import OpenGL.GL as gl
        indices = kwargs.get('indices', None)
        point_size_temp = self.point_size
        self.point_size = kwargs.get('point_size', 1)

        xsize = self._selection_box[2] - self._selection_box[0] + 1
        ysize = self._selection_box[3] - self._selection_box[1] + 1
        ids = np.zeros(xsize * ysize, int)

        self.create_display_lists(0, indices=indices)
        self.render_rgb_indexed_colors()
        gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
        pixels = gl.glReadPixelsub(self._selection_box[0],
                                   self._selection_box[1],
                                   xsize, ysize, gl.GL_RGBA)
        pixels = np.frombuffer(pixels, dtype=np.uint8).reshape(
            (ysize, xsize, 4))
        # Reassemble each pixel ID from the bits stored in its R, G, B and
        # A components (inverse of the encoding in create_display_lists).
        for i in range(4):
            component = pixels[:, :, i].reshape((xsize * ysize,)) \
                & self._rgba_masks[i]
            shift = (sum(self._rgba_bits[0:i]) - self._low_bits[i])
            if shift > 0:
                ids += component.astype(int) << shift
            else:
                ids += component.astype(int) >> (-shift)

        # ID 0 is dropped here (only IDs > 0 are returned).
        points = ids[ids > 0]

        self.point_size = point_size_temp
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
        self._refresh_display_lists = True

        return points

    def get_pixel_info(self, x, y, **kwargs):
        '''Prints row/col of the pixel at the given raster position.

        ARGUMENTS:

            `x`, `y`: (int):

                The pixel's coordinates relative to the lower left corner.
        '''
        self._selection_box = (x, y, x, y)
        ids = self.get_points_in_selection_box(point_size=self.point_size)
        for pixel_id in ids:
            if pixel_id > 0:
                rc = self.index_to_image_row_col(pixel_id)
                print('Pixel %d %s has class %s.'
                      % (pixel_id, rc, self.classes[rc]))
        return

    def render_rgb_indexed_colors(self, **kwargs):
        '''Draws scene in the background buffer to extract mouse click info'''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
        # camera_pos_rtp is relative to the target position. To get the
        # absolute camera position, we need to add the target position.
        gl.glPushMatrix()
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(
            *(list(camera_pos_xyz) + list(self.target_pos) + self.up))
        self.draw_data_set()
        gl.glPopMatrix()
        gl.glFlush()

    def index_to_image_row_col(self, index):
        '''Converts the unraveled pixel ID to row/col of the N-D image.'''
        ncols = self.data.shape[1]
        # Floor division: true division would give a float row, which
        # cannot be used to index the class array.
        return (index // ncols, index % ncols)

    def draw_data_set(self):
        '''Draws the N-D data set in the scene.'''
        import OpenGL.GL as gl

        # Lists 1..8 hold the octants; list 0 holds the axes.
        for i in range(1, 9):
            gl.glCallList(self.gllist_id + i)

    def create_axes_list(self):
        '''Creates display lists to render unit length x,y,z axes.'''
        import OpenGL.GL as gl
        gl.glNewList(self.gllist_id, gl.GL_COMPILE)
        gl.glBegin(gl.GL_LINES)
        # Positive semi-axes: x red, y green, z blue.
        gl.glColor3f(1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(1.0, 0.0, 0.0)
        gl.glColor3f(0.0, 1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 1.0, 0.0)
        gl.glColor3f(0.0, 0.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 1.0)
        # Negative semi-axes are all white.
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(-1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, -1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, -1.0)
        gl.glEnd()

        if self._have_glut:
            try:
                import OpenGL.GLUT as glut
                if bool(glut.glutBitmapCharacter):
                    gl.glRasterPos3f(1.05, 0.0, 0.0)
                    glut.glutBitmapCharacter(glut.GLUT_BITMAP_HELVETICA_18,
                                             ord('x'))
                    gl.glRasterPos3f(0.0, 1.05, 0.0)
                    glut.glutBitmapCharacter(glut.GLUT_BITMAP_HELVETICA_18,
                                             ord('y'))
                    gl.glRasterPos3f(0.0, 0.0, 1.05)
                    glut.glutBitmapCharacter(glut.GLUT_BITMAP_HELVETICA_18,
                                             ord('z'))
            except Exception:
                # Axis labels are optional; draw the axes without them.
                pass
        gl.glEndList()

    def GetGLExtents(self):
        """Get the extents of the OpenGL canvas."""
        return

    def SwapBuffers(self):
        """Swap the OpenGL buffers."""
        self.canvas.SwapBuffers()

    def on_erase_background(self, event):
        """Process the erase background event."""
        pass  # Do nothing, to avoid flashing on MSWin

    def initgl(self):
        '''App-specific initialization for after GLUT has been initialized.'''
        import OpenGL.GL as gl
        # One list for the axes plus one for each of the 8 octants.
        self.gllist_id = gl.glGenLists(9)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glDisable(gl.GL_LIGHTING)
        gl.glDisable(gl.GL_TEXTURE_2D)
        gl.glDisable(gl.GL_FOG)
        gl.glDisable(gl.GL_COLOR_MATERIAL)
        gl.glEnable(gl.GL_DEPTH_TEST)
        gl.glShadeModel(gl.GL_FLAT)
        self.set_data(self.data, classes=self.classes, features=self.features)

        try:
            import OpenGL.GLUT as glut
            glut.glutInit()
            self._have_glut = True
        except Exception:
            # GLUT is only used for axis labels; carry on without it.
            pass

    def on_resize(self, event):
        '''Process the resize event.'''

        # For wx versions 2.9.x, GLCanvas.GetContext() always returns None,
        # whereas 2.8.x will return the context so test for both versions.

        if wx.VERSION >= (2, 9) or self.canvas.GetContext():
            self.canvas.SetCurrent(self.canvas.context)
            # Make sure the frame is shown before calling SetCurrent.
            self.Show()
            size = event.GetSize()
            self.resize(size.width, size.height)
            self.canvas.Refresh(False)
        event.Skip()

    def resize(self, width, height):
        """Reshape the OpenGL viewport based on dimensions of the window."""
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        self.size = (width, height)
        gl.glViewport(0, 0, width, height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        # A zero height (e.g. a collapsed window) would divide by zero.
        aspect = float(width) / height if height else 1.0
        glu.gluPerspective(self.fovy, aspect, self.znear, self.zfar)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()

    def on_char(self, event):
        '''Callback function for when a keyboard button is pressed.'''
        key = chr(event.GetKeyCode())

        # See `print_help` method for explanation of keybinds.
        if key == 'a':
            self.show_axes_tf = not self.show_axes_tf
        elif key == 'c':
            self.view_class_image()
        elif key == 'd':
            if self.data.shape[2] < 6:
                print('Only single-quadrant mode is supported for %d features.'
                      % self.data.shape[2])
                return
            if self.quadrant_mode == 'single':
                self.quadrant_mode = 'mirrored'
            elif self.quadrant_mode == 'mirrored':
                self.quadrant_mode = 'independent'
            else:
                self.quadrant_mode = 'single'
            print('Setting quadrant display mode to %s.' % self.quadrant_mode)
            self.randomize_features()
        elif key == 'f':
            self.randomize_features()
        elif key == 'h':
            self.print_help()
        elif key == 'm':
            self.mouse_panning = not self.mouse_panning
        elif key == 'p':
            self.point_size += 1
            self._refresh_display_lists = True
        elif key == 'P':
            self.point_size = max(self.point_size - 1, 1.0)
            self._refresh_display_lists = True
        elif key == 'q':
            self.on_event_close()
            self.Close(True)
        elif key == 'r':
            self.reset_view_geometry()
        elif key == 'u':
            self._show_unassigned = not self._show_unassigned
            print('SHOW UNASSIGNED =', self._show_unassigned)
            self._refresh_display_lists = True

        self.canvas.Refresh()

    def update_window_title(self):
        '''Sets the window title to the default N-D data set title.'''
        s = 'SPy N-D Data Set'
        # The window is a wx.Frame, so the title is set through wx.
        self.SetTitle(s)

    def get_proxy(self):
        '''Returns a proxy object to access data from the window.'''
        return NDWindowProxy(self)

    def view_class_image(self, *args, **kwargs):
        '''Opens a dynamic raster image of class values.

        The class IDs displayed are those currently associated with the ND
        window. `args` and `kwargs` are additional arguments passed on to the
        `ImageView` constructor. Return value is the ImageView object.
        '''
        from .spypylab import ImageView
        view = ImageView(classes=self.classes, *args, **kwargs)
        # Share this window's registry so the view receives the
        # 'spy_classes_modified' events sent by reassign_selection.
        view.callbacks_common = self.callbacks
        view.show()
        return view

    def print_help(self):
        '''Prints a list of accepted keyboard/mouse inputs.'''
        print('''Mouse functions:
---------------
Left-click & drag       -->     Rotate viewing geometry (or pan)
CTRL+Left-click & drag  -->     Zoom viewing geometry
CTRL+SHIFT+Left-click   -->     Print image row/col and class of selected pixel
SHIFT+Left-click & drag -->     Define selection box in the window
Right-click             -->     Open GLUT menu for pixel reassignment

Keyboard functions:
-------------------
a       -->     Toggle axis display
c       -->     View dynamic raster image of class values
d       -->     Cycle display mode between single-quadrant, mirrored octants,
                and independent octants (display will not change until features
                are randomzed again)
f       -->     Randomize features displayed
h       -->     Print this help message
m       -->     Toggle mouse function between rotate/zoom and pan modes
p/P     -->     Increase/Decrease the size of displayed points
q       -->     Exit the application
r       -->     Reset viewing geometry
u       -->     Toggle display of unassigned points (points with class == 0)
''')
def __init__(self, artists, *, multiple=False, highlight=False,
             hover=False, bindings=default_bindings):
    """Construct a cursor.

    Parameters
    ----------
    artists : List[Artist]
        A list of artists that can be selected by this cursor.

    multiple : bool
        Whether multiple artists can be "on" at the same time (defaults to
        False).

    highlight : bool
        Whether to also highlight the selected artist.  If so,
        "highlighter" artists will be placed as the first item in the
        :attr:`extras` attribute of the `Selection`.

    bindings : dict
        A mapping of button and keybindings to actions.  Valid entries are:

        =================== ===============================================
        'select'            mouse button to select an artist (default: 1)
        'deselect'          mouse button to deselect an artist (default: 3)
        'left'              move to the previous point in the selected
                            path, or to the left in the selected image
                            (default: shift+left)
        'right'             move to the next point in the selected path, or
                            to the right in the selected image
                            (default: shift+right)
        'up'                move up in the selected image
                            (default: shift+up)
        'down'              move down in the selected image
                            (default: shift+down)
        'toggle_visibility' toggle visibility of all cursors (default: d)
        'toggle_enabled'    toggle whether the cursor is active
                            (default: t)
        =================== ===============================================

    hover : bool
        Whether to select artists upon hovering instead of by clicking.

    Raises
    ------
    ValueError
        If `hover` and `multiple` are both set, if `bindings` has a key
        that is not in the defaults, or if two actions share a binding.
    """
    artists = list(artists)

    # Be careful with GC: the instance only holds weak references to the
    # artists, while the class-level registry maps each artist to the
    # cursors attached to it.
    self._artists = list(map(weakref.ref, artists))
    registry = type(self)._keep_alive
    for target in artists:
        registry.setdefault(target, []).append(self)

    self._multiple = multiple
    self._highlight = highlight
    self._axes = set(target.axes for target in artists)
    self._enabled = True
    self._selections = []
    self._callbacks = CallbackRegistry()

    # Key presses are always watched; the mouse event depends on `hover`.
    connect_pairs = [("key_press_event", self._on_key_press)]
    if not hover:
        connect_pairs.append(("button_press_event", self._on_button_press))
    elif multiple:
        raise ValueError("`hover` and `multiple` are incompatible")
    else:
        connect_pairs.append(
            ("motion_notify_event", self._on_select_button_press))

    # One zero-argument disconnector per (event, canvas) connection.
    canvases = {target.figure.canvas for target in artists}
    disconnectors = []
    for event_name, handler in connect_pairs:
        for canvas in canvases:
            cid = canvas.mpl_connect(event_name, handler)
            disconnectors.append(partial(canvas.mpl_disconnect, cid))
    self._disconnect_cids = disconnectors

    # User bindings override the defaults.
    merged = {**default_bindings, **bindings}
    if set(merged) != set(default_bindings):
        raise ValueError("Unknown bindings")
    # Actions bound to None are disabled and may repeat freely.
    bound_values = [value for value in merged.values() if value is not None]
    if len(bound_values) != len(set(bound_values)):
        raise ValueError("Duplicate bindings")
    self._bindings = merged
class Cursor:
    """A cursor for selecting artists on a matplotlib figure.
    """

    # Maps each artist to the list of cursors attached to it.
    _keep_alive = WeakKeyDictionary()

    def __init__(self, artists, *, multiple=False, highlight=False,
                 hover=False, bindings=default_bindings):
        """Construct a cursor.

        Parameters
        ----------
        artists : List[Artist]
            A list of artists that can be selected by this cursor.

        multiple : bool
            Whether multiple artists can be "on" at the same time (defaults
            to False).

        highlight : bool
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.

        bindings : dict
            A mapping of button and keybindings to actions.  Valid entries
            are:

            =================== ===========================================
            'select'            mouse button to select an artist
                                (default: 1)
            'deselect'          mouse button to deselect an artist
                                (default: 3)
            'left'              move to the previous point in the selected
                                path, or to the left in the selected image
                                (default: shift+left)
            'right'             move to the next point in the selected
                                path, or to the right in the selected image
                                (default: shift+right)
            'up'                move up in the selected image
                                (default: shift+up)
            'down'              move down in the selected image
                                (default: shift+down)
            'toggle_visibility' toggle visibility of all cursors
                                (default: d)
            'toggle_enabled'    toggle whether the cursor is active
                                (default: t)
            =================== ===========================================

        hover : bool
            Whether to select artists upon hovering instead of by clicking.

        Raises
        ------
        ValueError
            If `hover` and `multiple` are both set, if `bindings` has a key
            that is not in the defaults, or if two actions share a binding.
        """
        artists = list(artists)
        # Be careful with GC.
        self._artists = [weakref.ref(artist) for artist in artists]
        for artist in artists:
            type(self)._keep_alive.setdefault(artist, []).append(self)

        self._multiple = multiple
        self._highlight = highlight

        self._axes = {artist.axes for artist in artists}
        self._enabled = True
        self._selections = []
        self._callbacks = CallbackRegistry()

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            if multiple:
                raise ValueError("`hover` and `multiple` are incompatible")
            connect_pairs += [
                ("motion_notify_event", self._on_select_button_press)]
        else:
            connect_pairs += [
                ("button_press_event", self._on_button_press)]
        # One zero-argument disconnector per (event, canvas) connection.
        self._disconnect_cids = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in {artist.figure.canvas for artist in artists}]

        bindings = {**default_bindings, **bindings}
        if set(bindings) != set(default_bindings):
            raise ValueError("Unknown bindings")
        # Actions bound to None are disabled and may repeat freely.
        actually_bound = {k: v for k, v in bindings.items() if v is not None}
        if len(set(actually_bound.values())) != len(actually_bound):
            raise ValueError("Duplicate bindings")
        self._bindings = bindings

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def artists(self):
        """The tuple of selectable artists.
        """
        # Dead weak references resolve to None and are filtered out.
        return tuple(filter(None, (ref() for ref in self._artists)))

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.
        """
        return tuple(self._selections)

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        ann = pi.artist.axes.annotate(_pick_info.get_ann_text(*pi),
                                      xy=pi.target,
                                      **default_annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            extras.append(self.add_highlight(pi.artist))
        if not self._multiple:
            # Single-selection mode: drop every earlier selection.
            while self._selections:
                self._remove_selection(self._selections[-1])
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)
        sel.artist.figure.canvas.draw_idle()
        return sel

    def add_highlight(self, artist):
        """Create, add and return a highlighting artist.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = copy.copy(artist)
        hl.set(**default_highlight_kwargs)
        artist.axes.add_artist(hl)
        return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as
        single argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on :mod:`matplotlib`'s implementation;
        in particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            # Decorator form: return a one-argument registrar.
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove all `Selection`\\s and disconnect all callbacks.
        """
        for disconnect_cid in self._disconnect_cids:
            disconnect_cid()
        while self._selections:
            self._remove_selection(self._selections[-1])

    def _on_button_press(self, event):
        if event.button == self._bindings["select"]:
            self._on_select_button_press(event)
        if event.button == self._bindings["deselect"]:
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of
        #     a double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return (self.enabled
                and event.canvas.widgetlock.locked() == event.dblclick)

    def _on_select_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {ax: _reassigned_axes_event(event, ax)
                          for ax in self._axes}
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        # Keep only the pick with the smallest `dist`.
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Iterate over a snapshot: `_remove_selection` mutates
        # `self._selections`, and removing from a list while iterating it
        # skips the element that follows each removal.
        for sel in list(self._selections):
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self._remove_selection(sel)

    def _on_key_press(self, event):
        if event.key == self._bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self._bindings["toggle_visibility"]:
            for sel in self._selections:
                sel.annotation.set_visible(not sel.annotation.get_visible())
                sel.annotation.figure.canvas.draw_idle()
        # The movement keys act on the most recent selection, if any.
        if self._selections:
            sel = self._selections[-1]
        else:
            return
        for key in ["left", "right", "up", "down"]:
            if event.key == self._bindings[key]:
                self._remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def _remove_selection(self, sel):
        self._selections.remove(sel)
        # Work around matplotlib/matplotlib#6785.
        draggable = sel.annotation._draggable
        try:
            draggable.disconnect()
            sel.annotation.figure.canvas.mpl_disconnect(
                sel.annotation._draggable._c1)
        except AttributeError:
            pass
        # (end of workaround).
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation, *sel.extras]}
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()
"""
cbook is short for "cookbook": a library made up of small utilities.
"""
from matplotlib.cbook import CallbackRegistry


def sum(x, y):
    # Callback for the "sum" signal: print the addition.
    print(f'{x}+{y}={x + y}')


def mul(x, y):
    # Callback for the "mul" signal: print the multiplication.
    print(f"{x} * {y}={x * y}")


callbacks = CallbackRegistry()

# connect() returns an id that can later be passed to disconnect().
id_sum = callbacks.connect("sum", sum)
id_mul = callbacks.connect("mul", mul)

# process() calls every callback connected to the named signal.
callbacks.process('sum', 3, 4)
callbacks.process("mul", 5, 6)

# Once disconnected, the "sum" callback no longer fires.
callbacks.disconnect(id_sum)
callbacks.process("sum", 7, 8)
def __init__(self, artists, *, multiple=False, highlight=False, hover=False,
             bindings=default_bindings):
    """Construct a cursor.

    Parameters
    ----------
    artists : List[Artist]
        A list of artists that can be selected by this cursor.

    multiple : bool
        Whether multiple artists can be "on" at the same time (defaults to
        False).

    highlight : bool
        Whether to also highlight the selected artist.  If so,
        "highlighter" artists will be placed as the first item in the
        :attr:`extras` attribute of the `Selection`.

    hover : bool
        Whether to select artists upon hovering instead of by clicking.

    bindings : dict
        A mapping of button and keybindings to actions.  Valid entries are:

        =================== ===============================================
        'select'            mouse button to select an artist (default: 1)
        'deselect'          mouse button to deselect an artist (default: 3)
        'left'              move to the previous point in the selected
                            path, or to the left in the selected image
                            (default: shift+left)
        'right'             move to the next point in the selected path, or
                            to the right in the selected image
                            (default: shift+right)
        'up'                move up in the selected image
                            (default: shift+up)
        'down'              move down in the selected image
                            (default: shift+down)
        'toggle_visibility' toggle visibility of all cursors (default: d)
        'toggle_enabled'    toggle whether the cursor is active
                            (default: t)
        =================== ===============================================

        Entries that are not given fall back to the defaults; a value of
        ``None`` is excluded from the duplicate check.

    Raises
    ------
    ValueError
        If `hover` and `multiple` are both set, if `bindings` contains a
        key that is not a known action, or if two actions share the same
        non-``None`` binding.
    """
    # Validate everything up front: the registration in `_keep_alive` and
    # the canvas connections below are side effects that would otherwise
    # outlive a constructor call that ends up raising.
    if hover and multiple:
        raise ValueError("`hover` and `multiple` are incompatible")
    bindings = {**default_bindings, **bindings}
    if set(bindings) != set(default_bindings):
        raise ValueError("Unknown bindings")
    actually_bound = {k: v for k, v in bindings.items() if v is not None}
    if len(set(actually_bound.values())) != len(actually_bound):
        raise ValueError("Duplicate bindings")

    artists = list(artists)
    # Be careful with GC.
    self._artists = [weakref.ref(artist) for artist in artists]
    # NOTE(review): `_keep_alive` is a class-level mapping defined outside
    # this block; presumably it ties the cursor's lifetime to its artists
    # -- confirm against the class body.
    for artist in artists:
        type(self)._keep_alive.setdefault(artist, []).append(self)

    self._multiple = multiple
    self._highlight = highlight

    self._axes = {artist.axes for artist in artists}
    self._enabled = True
    self._selections = []
    self._callbacks = CallbackRegistry()

    connect_pairs = [("key_press_event", self._on_key_press)]
    if hover:
        connect_pairs += [("motion_notify_event",
                           self._on_select_button_press)]
    else:
        connect_pairs += [("button_press_event", self._on_button_press)]
    # One zero-argument disconnector per (event, canvas) connection.
    canvases = {artist.figure.canvas for artist in artists}
    self._disconnect_cids = [
        partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
        for pair in connect_pairs
        for canvas in canvases
    ]

    self._bindings = bindings
def __init__(self, artists, *, multiple=False, bindings=None,
             annotation_kwargs=None, annotation_positions=None):
    """Construct a cursor.

    Parameters
    ----------
    artists : List[Artist]
        A list of artists that can be selected by this cursor.

    multiple : bool, optional
        Whether multiple artists can be "on" at the same time (defaults to
        False).

    bindings : dict, optional
        A mapping of button and keybindings to actions.  Valid entries are:

        ================ ==================================================
        'select'         mouse button to select an artist (default: 1)
        'deselect'       mouse button to deselect an artist (default: 3)
        'left'           move to the previous point in the selected path,
                         or to the left in the selected image
                         (default: shift+left)
        'right'          move to the next point in the selected path, or to
                         the right in the selected image
                         (default: shift+right)
        'up'             move up in the selected image
                         (default: shift+up)
        'down'           move down in the selected image
                         (default: shift+down)
        'toggle_enabled' toggle whether the cursor is active (default: e)
        'toggle_visible' toggle default cursor visibility and apply it to
                         all cursors (default: v)
        ================ ==================================================

        Missing entries will be set to the defaults.  In order to not
        assign any binding to an action, set it to ``None``.

    annotation_kwargs : dict, optional
        Keyword arguments passed to the `annotate
        <matplotlib.axes.Axes.annotate>` call.

    annotation_positions : List[dict], optional
        List of positions tried by the annotation positioning algorithm.
        When given as a list it replaces the default positions; a mapping
        is merged into the defaults key by key instead.

    Raises
    ------
    ValueError
        If `bindings` contains an unknown action, or assigns the same
        non-``None`` binding to more than one action.
    """
    # Validate before connecting anything to the canvases, so that a
    # rejected call leaves no callbacks behind.
    if bindings is not None:
        unknown_bindings = set(bindings) - set(_default_bindings)
        if unknown_bindings:
            raise ValueError("Unknown binding(s): {}".format(
                ", ".join(sorted(unknown_bindings))))
        # ``None`` means "no binding" and may legitimately be used for
        # several actions, so it is left out of the duplicate count.
        # NOTE(review): only the user-supplied values are compared, as
        # before; a clash with an untouched default is not detected.
        binding_counts = Counter(
            value for value in bindings.values() if value is not None)
        duplicate_bindings = [
            value for value, count in binding_counts.items() if count > 1]
        if duplicate_bindings:
            raise ValueError("Duplicate binding(s): {}".format(
                ", ".join(sorted(map(str, duplicate_bindings)))))

    # Copy so that one-shot iterables work and later changes to the
    # caller's list do not affect this cursor.
    self._artists = list(artists)
    self._multiple = multiple
    self._visible = True
    self._enabled = True
    self._selections = []
    self._last_auto_position = None
    self._last_active_selection = -1
    self._callbacks = CallbackRegistry()

    connect_pairs = [
        ('key_press_event', self._on_key_press),
        ('button_press_event', self._mouse_click_handler),
        ('pick_event', self._pick_event_handler)
    ]
    # One zero-argument disconnector per (event, canvas) connection.
    canvases = {artist.figure.canvas for artist in self._artists}
    self._disconnectors = [
        partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
        for pair in connect_pairs
        for canvas in canvases
    ]

    # Always work on a private copy: aliasing the module-level defaults
    # would let a change to one cursor's bindings leak into the defaults.
    self.bindings = copy.deepcopy(_default_bindings)
    if bindings is not None:
        self.bindings.update(bindings)

    self.annotation_kwargs = copy.deepcopy(_default_annotation_kwargs)
    if annotation_kwargs is not None:
        self.annotation_kwargs.update(annotation_kwargs)

    self.annotation_positions = copy.deepcopy(_default_annotation_positions)
    if annotation_positions is not None:
        if hasattr(annotation_positions, "items"):
            # Mapping input: keep the previous key-by-key merge.
            for key, value in annotation_positions.items():
                self.annotation_positions[key] = value
        else:
            # Documented form (a list of dicts): use it in place of the
            # defaults instead of failing on the missing ``.items()``.
            self.annotation_positions = list(annotation_positions)