class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    # Properties of the rectangle drawn by the interactive RectangleSelector
    # (passed as `rectprops` in init_callbacks).
    selector_rectprops = dict(facecolor='red',
                              edgecolor='black',
                              alpha=0.5,
                              fill=True)
    selector_lineprops = dict(color='black',
                              linestyle='-',
                              linewidth=2,
                              alpha=0.5)

    def __init__(self,
                 data=None,
                 bands=None,
                 classes=None,
                 source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed, with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display).

        Keyword arguments:

            Any keyword that can be provided to :func:`~spectral.graphics.graphics.get_rgb`
            or :func:`matplotlib.imshow`.
        '''

        import spectral
        from spectral import settings
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        # Default class colors are assigned *before* set_classes runs so that
        # a `colors` value forwarded to set_classes is not overwritten.
        self.class_colors = spectral.spy_colors

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises ValueError if the first two dimensions of `data` do not match
        the shape of previously set data or classes.
        '''
        from .graphics import _get_rgb_kwargs

        self.data = data
        self.bands = bands

        # Split keywords: get_rgb options go to set_rgb_options; everything
        # left over is forwarded to matplotlib imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        elif data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')
        self.imshow_data_kwargs.update(kwargs)
        # Interpolation is tracked on the view (shared by data and class
        # layers) rather than as a per-layer imshow keyword.
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs['interpolation']
            self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Sets parameters affecting RGB display of data.

        Accepts any keyword supported by :func:`~spectral.graphics.graphics.get_rgb`.
        Raises ValueError for any other keyword.
        '''
        from .graphics import _get_rgb_kwargs

        for k in kwargs:
            if k not in _get_rgb_kwargs:
                raise ValueError('Unexpected keyword: {0}'.format(k))
        self.rgb_kwargs = kwargs.copy()
        if self.is_shown:
            self._update_data_rgb()
            self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.'''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used (imshow ignores the
        # colormap for 3-channel arrays).
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number of rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.
            Keywords accepted by `get_rgb` are ignored here.

        Raises ValueError if the shape of `classes` is inconsistent with
        previously set data.
        '''
        from .graphics import _get_rgb_kwargs
        self.classes = classes
        if classes is None:
            return
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        elif classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')
        if colors is not None:
            self.class_colors = colors

        kwargs = {k: v for (k, v) in kwargs.items() if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs['interpolation']
            self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Renders the image data.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                If `mode` is not provided, a mode will be automatically
                selected, based on the data set in the ImageView.

            `fignum` (int):

                Figure number of the matplotlib figure in which to display
                the ImageView. If not provided, a new figure will be created.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            msg = 'ImageView.show should only be called once.'
            warnings.warn(UserWarning(msg))
            return

        set_mpl_interactive()

        kwargs = {}
        if fignum is not None:
            kwargs['num'] = fignum
        if settings.imshow_figure_size is not None:
            kwargs['figsize'] = settings.imshow_figure_size
        plt.figure(**kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is None:
            self._guess_mode()
        else:
            self.set_display_mode(mode)

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.'''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: adopt the shared class array if this
        # view has none yet, then redraw.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()

        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops = \
                                                  self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # Best effort: interactive labeling is optional, so degrade to a
            # warning (e.g., if the installed matplotlib rejects these
            # RectangleSelector keywords).
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`).

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`).
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away. The selector is None when it
            # was disabled via settings or failed to be created.
            if self.selector is not None:
                self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        '''RectangleSelector callback: stores the selected pixel region.

        Sets `self.selection` to [row_start, row_stop, col_start, col_stop]
        (stop indices exclusive, clipped to the image), or to None if the
        selection lies outside the image or the axes.
        '''
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
          (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data.'''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        # class_colors holds 0-255 triplets; matplotlib expects 0-1 floats.
        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        # NoNorm maps each integer class value directly to a colormap index.
        kwargs.update({
            'cmap': cm,
            'vmin': 0,
            'norm': NoNorm(),
            'interpolation': self._interpolation
        })
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Rebuilds `class_rgb` from `classes` for the current display mode.

        In "overlay" mode, class 0 is masked so the data layer shows through.
        If no classes are set, `class_rgb` becomes None.
        '''
        if self.classes is None:
            self.class_rgb = None
        elif self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes,
                                         mask=(self.classes == 0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").

        Raises ValueError for any other value.
        '''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        # Chained comparison also rejects NaN, which `alpha < 0 or alpha > 1`
        # would let through.
        if not 0 <= alpha <= 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the axes title. Has no effect until the view is shown.'''
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.callbacks_common = self.callbacks_common
        # show() applies the display mode itself. Calling set_display_mode
        # beforehand would render the class layer early and make show() emit
        # a spurious "show_classes should only be called once" warning.
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out (`scale` > 1 zooms in).'''
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.'''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except (IndexError, TypeError, ValueError):
                # Coordinate rounding can land just outside the array, or the
                # value may not be %d-formattable; just omit the class.
                pass
        return s

    def __str__(self):
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += '  {0:<20}:  {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation is None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += '  {0:<20}:  {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += '  {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += '    {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
# Example #2
class Cursor:
    """A cursor for selecting artists on a matplotlib figure.
    """

    # Maps each artist (held weakly) to the cursors attached to it, so that a
    # cursor stays alive as long as any of its artists does, even if the user
    # drops their own reference to it. `remove` unregisters the cursor.
    _keep_alive = WeakKeyDictionary()

    def __init__(self,
                 artists,
                 *,
                 multiple=False,
                 highlight=False,
                 hover=False,
                 bindings=default_bindings):
        """Construct a cursor.

        Parameters
        ----------

        artists : List[Artist]
            A list of artists that can be selected by this cursor.

        multiple : bool
            Whether multiple artists can be "on" at the same time (defaults to
            False).

        highlight : bool
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.

        hover : bool
            Whether to select artists upon hovering instead of by clicking.
            Incompatible with `multiple` (raises ValueError).

        bindings : dict
            A mapping of button and keybindings to actions.  Valid entries are:

            =================== ===============================================
            'select'            mouse button to select an artist (default: 1)
            'deselect'          mouse button to deselect an artist (default: 3)
            'left'              move to the previous point in the selected
                                path, or to the left in the selected image
                                (default: shift+left)
            'right'             move to the next point in the selected path, or
                                to the right in the selected image
                                (default: shift+right)
            'up'                move up in the selected image
                                (default: shift+up)
            'down'              move down in the selected image
                                (default: shift+down)
            'toggle_visibility' toggle visibility of all cursors (default: d)
            'toggle_enabled'    toggle whether the cursor is active
                                (default: t)
            =================== ===============================================

            Unknown keys, or two actions bound to the same value, raise
            ValueError.
        """

        artists = list(artists)
        # Be careful with GC.
        self._artists = [weakref.ref(artist) for artist in artists]

        for artist in artists:
            type(self)._keep_alive.setdefault(artist, []).append(self)

        self._multiple = multiple
        self._highlight = highlight

        self._axes = {artist.axes for artist in artists}
        self._enabled = True
        self._selections = []
        self._callbacks = CallbackRegistry()

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            if multiple:
                raise ValueError("`hover` and `multiple` are incompatible")
            connect_pairs += [("motion_notify_event",
                               self._on_select_button_press)]
        else:
            connect_pairs += [("button_press_event", self._on_button_press)]
        # One disconnector per (event, canvas) pair; called by `remove`.
        self._disconnect_cids = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in {artist.figure.canvas
                           for artist in artists}
        ]

        bindings = {**default_bindings, **bindings}
        if set(bindings) != set(default_bindings):
            raise ValueError("Unknown bindings")
        actually_bound = {k: v for k, v in bindings.items() if v is not None}
        if len(set(actually_bound.values())) != len(actually_bound):
            raise ValueError("Duplicate bindings")
        self._bindings = bindings

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def artists(self):
        """The tuple of selectable artists (those not yet garbage collected).
        """
        return tuple(filter(None, (ref() for ref in self._artists)))

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.
        """
        return tuple(self._selections)

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        ann = pi.artist.axes.annotate(_pick_info.get_ann_text(*pi),
                                      xy=pi.target,
                                      **default_annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            extras.append(self.add_highlight(pi.artist))
        if not self._multiple:
            while self._selections:
                self._remove_selection(self._selections[-1])
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)
        sel.artist.figure.canvas.draw_idle()
        return sel

    def add_highlight(self, artist):
        """Create, add and return a highlighting artist.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = copy.copy(artist)
        hl.set(**default_highlight_kwargs)
        artist.axes.add_artist(hl)
        return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as single
        argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on :mod:`matplotlib`'s implementation; in
        particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove all `Selection`\\s and disconnect all callbacks.

        Also unregisters this cursor from the class-level keep-alive mapping,
        so a removed cursor is no longer kept alive by its artists.
        """
        for disconnect_cid in self._disconnect_cids:
            disconnect_cid()
        while self._selections:
            self._remove_selection(self._selections[-1])
        for cursors in type(self)._keep_alive.values():
            while self in cursors:
                cursors.remove(self)

    def _on_button_press(self, event):
        if event.button == self._bindings["select"]:
            self._on_select_button_press(event)
        if event.button == self._bindings["deselect"]:
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of a
        #     double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return (self.enabled
                and event.canvas.widgetlock.locked() == event.dblclick)

    def _on_select_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {
            ax: _reassigned_axes_event(event, ax)
            for ax in self._axes
        }
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        # Select the candidate closest to the event.
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Iterate over a snapshot: _remove_selection mutates self._selections,
        # and removing from a list while iterating it would skip entries.
        for sel in list(self._selections):
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self._remove_selection(sel)

    def _on_key_press(self, event):
        if event.key == self._bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self._bindings["toggle_visibility"]:
            for sel in self._selections:
                sel.annotation.set_visible(not sel.annotation.get_visible())
                sel.annotation.figure.canvas.draw_idle()
        if self._selections:
            sel = self._selections[-1]
        else:
            return
        # Arrow-style bindings move the most recent selection.
        for key in ["left", "right", "up", "down"]:
            if event.key == self._bindings[key]:
                self._remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def _remove_selection(self, sel):
        self._selections.remove(sel)
        # Work around matplotlib/matplotlib#6785.
        draggable = sel.annotation._draggable
        try:
            draggable.disconnect()
            sel.annotation.figure.canvas.mpl_disconnect(
                sel.annotation._draggable._c1)
        except AttributeError:
            pass
        # (end of workaround).
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation, *sel.extras]}
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()
# Example #3
class scatter_selector(AxesWidget):
    """
    A widget for selecting a point in a scatter plot.

    Callbacks registered with `on_changed` receive ``(index, (x, y))`` for
    the picked point; the same pair is stored in :attr:`val`.
    """

    def __init__(self, ax, x, y, pickradius=5, which_button=1, **kwargs):
        """
        Create the scatter plot and selection machinery.

        Parameters
        ----------
        ax : Axes
            The Axes on which to make the scatter plot
        x, y : array-like, shape (n, )
            The data positions.  They are indexed by point number, so they
            must be sequences.
        pickradius : float
            Pick radius, in points.
        which_button : int, default: 1
            Where 1=left, 2=middle, 3=right

        Other Parameters
        ----------------
        **kwargs : arguments to scatter
            Other keyword arguments are passed directly to the ``ax.scatter`` command

        """
        super().__init__(ax)
        self.scatter = ax.scatter(x, y, **kwargs, picker=True)
        self.scatter.set_pickradius(pickradius)
        self._observers = CallbackRegistry()
        self._x = x
        self._y = y
        self._button = which_button
        self.connect_event("pick_event", self._on_pick)
        self._init_val()

    def _init_val(self):
        """Use the first data point as the initial value."""
        self.val = (0, (self._x[0], self._y[0]))

    def _on_pick(self, event):
        """Handle a pick event, ignoring artists other than our scatter."""
        # "pick_event" is emitted for every pickable artist on the canvas;
        # indices from another artist would not refer to our data.
        if event.artist is not self.scatter:
            return
        if event.mouseevent.button == self._button:
            # Several points may fall within the pick radius; take the first.
            idx = event.ind[0]
            x = self._x[idx]
            y = self._y[idx]
            self._process(idx, (x, y))

    def _process(self, idx, val):
        """Store the picked point as the current value and notify observers."""
        self.val = (idx, val)
        self._observers.process("picked", idx, val)

    def on_changed(self, func):
        """
        When a point is clicked, call *func* with the newly selected point.

        Parameters
        ----------
        func : callable
            Function to call when a point is picked.
            The function must accept a (int, tuple(float, float)) as its arguments.

        Returns
        -------
        int
            Connection id (which can be used to disconnect *func*)
        """
        # Wrapping in a lambda gives CallbackRegistry a plain function, which
        # it holds strongly (bound methods are only weakly referenced), so
        # *func* stays alive while connected.
        return self._observers.connect("picked", lambda idx, val: func(idx, val))
Пример #4
0
class Cursor:
    """A cursor for selecting Matplotlib artists.

    Attributes
    ----------
    bindings : dict
        See the *bindings* keyword argument to the constructor.
    annotation_kwargs : dict
        See the *annotation_kwargs* keyword argument to the constructor.
    annotation_positions : List[dict]
        See the *annotation_positions* keyword argument to the constructor.
    highlight_kwargs : dict
        See the *highlight_kwargs* keyword argument to the constructor.
    """

    # Artist -> set of cursors attached to it.  Cursors only hold weak
    # references to their artists; this mapping keeps a cursor alive for as
    # long as one of its artists is alive (until `remove` is called).
    _keep_alive = WeakKeyDictionary()

    def __init__(self,
                 artists,
                 *,
                 multiple=False,
                 highlight=False,
                 hover=False,
                 bindings=None,
                 annotation_kwargs=None,
                 annotation_positions=None,
                 highlight_kwargs=None):
        """Construct a cursor.

        Parameters
        ----------

        artists : List[Artist]
            A list of artists that can be selected by this cursor.

        multiple : bool, optional
            Whether multiple artists can be "on" at the same time (defaults to
            False).

        highlight : bool, optional
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.

        hover : bool, optional
            Whether to select artists upon hovering instead of by clicking.
            (Hovering over an artist while a button is pressed will not trigger
            a selection; right clicking on an annotation will still remove it.)

        bindings : dict, optional
            A mapping of button and keybindings to actions.  Valid entries are:

            ================ ==================================================
            'select'         mouse button to select an artist
                             (default: 1)
            'deselect'       mouse button to deselect an artist
                             (default: 3)
            'left'           move to the previous point in the selected path,
                             or to the left in the selected image
                             (default: shift+left)
            'right'          move to the next point in the selected path, or to
                             the right in the selected image
                             (default: shift+right)
            'up'             move up in the selected image
                             (default: shift+up)
            'down'           move down in the selected image
                             (default: shift+down)
            'toggle_enabled' toggle whether the cursor is active
                             (default: e)
            'toggle_visible' toggle default cursor visibility and apply it to
                             all cursors (default: v)
            ================ ==================================================

            Missing entries will be set to the defaults.  In order to not
            assign any binding to an action, set it to ``None`` (several
            actions may be unbound at once).

        annotation_kwargs : dict, optional
            Keyword arguments passed to the `annotate
            <matplotlib.axes.Axes.annotate>` call.

        annotation_positions : List[dict], optional
            List of positions tried by the annotation positioning algorithm.

        highlight_kwargs : dict, optional
            Keyword arguments used to create a highlighted artist.

        Raises
        ------
        ValueError
            If *hover* and *multiple* are both set, or if *bindings* contains
            unknown actions or maps two actions to the same button/key.
        """

        artists = list(artists)

        # Validate every argument before registering the cursor anywhere, so
        # that a failed construction leaves no connected handlers or
        # keep-alive entries pointing at a half-initialized cursor.
        if hover and multiple:
            raise ValueError("'hover' and 'multiple' are incompatible")

        bindings = dict(
            ChainMap(bindings if bindings is not None else {},
                     _default_bindings))
        unknown_bindings = set(bindings) - set(_default_bindings)
        if unknown_bindings:
            raise ValueError("Unknown binding(s): {}".format(", ".join(
                sorted(unknown_bindings))))
        # ``None`` means "unbound", so it may legitimately appear many times.
        duplicate_bindings = [
            k for k, v in Counter(list(bindings.values())).items()
            if v > 1 and k is not None
        ]
        if duplicate_bindings:
            raise ValueError("Duplicate binding(s): {}".format(", ".join(
                sorted(map(str, duplicate_bindings)))))
        self.bindings = bindings

        self.annotation_kwargs = (annotation_kwargs
                                  if annotation_kwargs is not None else
                                  copy.deepcopy(_default_annotation_kwargs))
        self.annotation_positions = (
            annotation_positions if annotation_positions is not None else
            copy.deepcopy(_default_annotation_positions))
        self.highlight_kwargs = (highlight_kwargs
                                 if highlight_kwargs is not None else
                                 copy.deepcopy(_default_highlight_kwargs))

        # Be careful with GC: only keep weak references to the artists.
        self._artists = [weakref.ref(artist) for artist in artists]

        self._multiple = multiple
        self._highlight = highlight

        self._visible = True
        self._enabled = True
        self._selections = []
        self._last_auto_position = None
        self._callbacks = CallbackRegistry()

        for artist in artists:
            type(self)._keep_alive.setdefault(artist, set()).add(self)

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            connect_pairs += [
                ("motion_notify_event", self._hover_handler),
                ("button_press_event", self._hover_handler),
            ]
        else:
            connect_pairs += [("button_press_event", self._nonhover_handler)]
        self._disconnectors = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in {artist.figure.canvas
                           for artist in artists}
        ]

    @property
    def artists(self):
        """The tuple of selectable artists.
        """
        # Work around matplotlib/matplotlib#6982: `cla()` does not clear
        # `.axes`.
        return tuple(filter(_is_alive, (ref() for ref in self._artists)))

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.

        Raises `RuntimeError` if an annotation was removed behind the
        cursor's back instead of through `remove_selection`.
        """
        for sel in self._selections:
            if sel.annotation.axes is None:
                raise RuntimeError("Annotation unexpectedly removed; "
                                   "use 'cursor.remove_selection' instead")
        return tuple(self._selections)

    @property
    def visible(self):
        """Whether selections are visible by default.

        Setting this property also updates the visibility status of current
        selections.
        """
        return self._visible

    @visible.setter
    def visible(self, value):
        self._visible = value
        for sel in self.selections:
            sel.annotation.set_visible(value)
            sel.annotation.figure.canvas.draw_idle()

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.  When
        the event is emitted, the position of the annotation is temporarily
        set to ``(nan, nan)``; if this position is not explicitly set by a
        callback, then a suitable position will be automatically computed.

        Likewise, if the text alignment is not explicitly set but the position
        is, then a suitable alignment will be automatically computed.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        # Pre-fetch the figure and axes, as callbacks may actually unset them.
        figure = pi.artist.figure
        axes = pi.artist.axes
        if axes.get_renderer_cache() is None:
            figure.canvas.draw()  # Needed by draw_artist below anyways.
        renderer = axes.get_renderer_cache()
        ann = axes.annotate(_pick_info.get_ann_text(*pi),
                            xy=pi.target,
                            xytext=(np.nan, np.nan),
                            ha=_MarkedStr("center"),
                            va=_MarkedStr("center"),
                            visible=self.visible,
                            **self.annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            hl = self.add_highlight(*pi)
            if hl:
                extras.append(hl)
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)

        # Check that `ann.axes` is still set, as callbacks may have removed the
        # annotation.  NaN never equals NaN (a tuple `==` against the
        # placeholder only matches through an object-identity shortcut), so
        # detect the untouched placeholder position with `isnan`.
        if ann.axes and np.isnan(ann.xyann).all():
            fig_bbox = figure.get_window_extent()
            ax_bbox = axes.get_window_extent()
            overlaps = []
            for idx, annotation_position in enumerate(
                    self.annotation_positions):
                ann.set(**annotation_position)
                # Work around matplotlib/matplotlib#7614: position update is
                # missing.
                ann.update_positions(renderer)
                bbox = ann.get_window_extent(renderer)
                overlaps.append((
                    _get_rounded_intersection_area(fig_bbox, bbox),
                    _get_rounded_intersection_area(ax_bbox, bbox),
                    # Avoid needlessly jumping around by breaking ties using
                    # the last used position as default.
                    idx == self._last_auto_position,
                ))
            auto_position = max(range(len(overlaps)), key=overlaps.__getitem__)
            ann.set(**self.annotation_positions[auto_position])
            self._last_auto_position = auto_position
        else:
            # Only fill in alignments that a callback did not set explicitly
            # (the defaults are tagged with `_MarkedStr`).
            if isinstance(ann.get_ha(), _MarkedStr):
                ann.set_ha({
                    -1: "right",
                    0: "center",
                    1: "left"
                }[np.sign(np.nan_to_num(ann.xyann[0]))])
            if isinstance(ann.get_va(), _MarkedStr):
                ann.set_va({
                    -1: "top",
                    0: "center",
                    1: "bottom"
                }[np.sign(np.nan_to_num(ann.xyann[1]))])

        if (extras
                or (len(self.selections) > 1 and not self._multiple)
                or not figure.canvas.supports_blit):
            # Either:
            #  - there may be more things to draw, or
            #  - annotation removal will make a full redraw necessary, or
            #  - blitting is not (yet) supported.
            figure.canvas.draw_idle()
        elif ann.axes:
            # Fast path, only needed if the annotation has not been immediately
            # removed.
            figure.draw_artist(ann)
            # Explicit argument needed on MacOSX backend.
            figure.canvas.blit(figure.bbox)
        # Removal comes after addition so that the fast blitting path works.
        if not self._multiple:
            # Use a distinct loop name: rebinding `sel` here would make this
            # method return a just-removed selection instead of the new one.
            for other in self.selections[:-1]:
                self.remove_selection(other)
        return sel

    def add_highlight(self, artist, *args, **kwargs):
        """Create, add and return a highlighting artist.

        This method should be called with an "unpacked" `Selection`,
        possibly with some fields set to None.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = _pick_info.make_highlight(
            artist, *args,
            **ChainMap({"highlight_kwargs": self.highlight_kwargs}, kwargs))
        if hl:
            artist.axes.add_artist(hl)
            return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as single
        argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on Matplotlib's implementation; in
        particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...

        Examples of callbacks::

            # Change the annotation text and alignment:
            lambda sel: sel.annotation.set(
                text=sel.artist.get_label(),  # or use e.g. sel.target.index
                ha="center", va="bottom")

            # Make label non-draggable:
            lambda sel: sel.annotation.draggable(False)
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove a cursor.

        Remove all `Selection`\\s, disconnect all callbacks, and allow the
        cursor to be garbage collected.
        """
        for disconnector in self._disconnectors:
            disconnector()
        for sel in self.selections:
            self.remove_selection(sel)
        for s in type(self)._keep_alive.values():
            with suppress(KeyError):
                s.remove(self)

    def _nonhover_handler(self, event):
        if event.name == "button_press_event":
            if event.button == self.bindings["select"]:
                self._on_select_button_press(event)
            if event.button == self.bindings["deselect"]:
                self._on_deselect_button_press(event)

    def _hover_handler(self, event):
        if event.name == "motion_notify_event" and event.button is None:
            # Filter away events where the mouse is pressed, in particular to
            # avoid conflicts between hover and draggable.
            self._on_select_button_press(event)
        elif (event.name == "button_press_event"
              and event.button == self.bindings["deselect"]):
            # Still allow removing the annotation by right clicking.
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of a
        #     double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return self.enabled and event.canvas.widgetlock.locked(
        ) == event.dblclick

    def _on_select_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {
            ax: _reassigned_axes_event(event, ax)
            for ax in {artist.axes
                       for artist in self.artists}
        }
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        # Select the candidate closest to the event.
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        if not self._filter_mouse_event(event):
            return
        # `self.selections` is a tuple snapshot, so removing while iterating
        # is safe here.
        for sel in self.selections:
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self.remove_selection(sel)

    def _on_key_press(self, event):
        if event.key == self.bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self.bindings["toggle_visible"]:
            self.visible = not self.visible
        try:
            sel = self.selections[-1]
        except IndexError:
            return
        for key in ["left", "right", "up", "down"]:
            if event.key == self.bindings[key]:
                self.remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def remove_selection(self, sel):
        """Remove a `Selection`.

        Emits the ``"remove"`` event with *sel* as argument.
        """
        self._selections.remove(sel)
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation] + sel.extras}
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()
"""
cbook即为cookbook,是一些小工具组成的库

"""
from matplotlib.cbook import CallbackRegistry

callbacks = CallbackRegistry()
sum = lambda x, y: print(f'{x}+{y}={x + y}')
mul = lambda x, y: print(f"{x} * {y}={x * y}")
id_sum = callbacks.connect("sum", sum)
id_mul = callbacks.connect("mul", mul)
callbacks.process('sum', 3, 4)
callbacks.process("mul", 5, 6)
callbacks.disconnect(id_sum)
callbacks.process("sum", 7, 8)
Пример #6
0
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    selector_rectprops = dict(facecolor='red', edgecolor = 'black',
                              alpha=0.5, fill=True)
    selector_lineprops = dict(color='black', linestyle='-',
                              linewidth = 2, alpha=0.5)
    def __init__(self, data=None, bands=None, classes=None, source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed, with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display).

        Keyword arguments:

            Any keyword that can be provided to :func:`~spectral.graphics.graphics.get_rgb`
            or :func:`matplotlib.imshow`. The same keywords are forwarded to
            both `set_data` and `set_classes`.
        '''

        import spectral
        from spectral import settings
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None

        # Defaults must be assigned before the setters below run: e.g.,
        # `set_classes` replaces `class_colors` when given `colors`, and
        # assigning the default afterwards would silently discard that.
        self.class_colors = spectral.spy_colors

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display. Its first two
                dimensions must match any previously set data or classes.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given. `get_rgb` keywords become the new RGB options; the rest
            are stored for `imshow` and only take effect if given before the
            image is shown.

        Raises:

            `ValueError` if the shape of `data` is inconsistent with
            previously set data. In that case the view is left unchanged.
        '''
        from .graphics import _get_rgb_kwargs

        # Validate before modifying any state so that rejected data does not
        # leave the view half-updated.
        if self._image_shape is not None and \
           data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously '
                             'set data.')

        self.data = data
        self.bands = bands

        # Split off the keywords meant for `get_rgb`; the rest go to imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        self.imshow_data_kwargs.update(kwargs)
        # Interpolation is managed by the view itself, not passed to imshow.
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs.pop('interpolation')

        if kwargs and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Replaces the options used to generate the displayed RGB data.

        Every keyword must be one accepted by
        :func:`~spectral.graphics.graphics.get_rgb`; a `ValueError` naming the
        first unsupported keyword is raised otherwise. If the view is
        already shown, the RGB data is regenerated and the display refreshed.
        '''
        from .graphics import _get_rgb_kwargs

        unexpected = [name for name in kwargs if name not in _get_rgb_kwargs]
        if unexpected:
            raise ValueError('Unexpected keyword: {0}'.format(unexpected[0]))
        self.rgb_kwargs = dict(kwargs)
        if not self.is_shown:
            return
        self._update_data_rgb()
        self.refresh()
        
    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.

        Stores the output of `get_rgb_meta` (computed from `data`, `bands`
        and `rgb_kwargs`) in `data_rgb` and `data_rgb_meta`. For monochrome
        results, `data_rgb` is reduced to its first channel.
        '''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used (imshow applies `cmap`
        # only to 2-D arrays).
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set).
                Passing `None` clears the classes.

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.
            Keywords meant for `get_rgb` are ignored.

        Raises:

            `ValueError` if the shape of `classes` is inconsistent with
            previously set data. In that case the view is left unchanged.
        '''
        from .graphics import _get_rgb_kwargs
        if classes is None:
            self.classes = None
            return
        # Validate before modifying any state so that a rejected array does
        # not replace the current classes.
        if self._image_shape is not None and \
           classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with '
                             'previously set data.')
        self.classes = classes
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        if colors is not None:
            self.class_colors = colors

        kwargs = {k: v for k, v in kwargs.items() if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        # Interpolation is managed by the view itself, not passed to imshow.
        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs.pop('interpolation')

        if kwargs and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source
    
    def show(self, mode=None, fignum=None):
        '''Displays the view in a new (or numbered) matplotlib figure.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                When omitted, the mode is guessed from the data and classes
                that have been set.

            `fignum` (int):

                Number of the matplotlib figure to draw into. When omitted, a
                new figure is created.

        Calling this more than once only emits a `UserWarning`.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            warnings.warn(UserWarning('ImageView.show should only be called once.'))
            return

        set_mpl_interactive()

        figure_kwargs = {}
        if fignum is not None:
            figure_kwargs['num'] = fignum
        figure_size = settings.imshow_figure_size
        if figure_size is not None:
            figure_kwargs['figsize'] = figure_size
        plt.figure(**figure_kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is not None:
            self.set_display_mode(mode)
        else:
            self._guess_mode()

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.

        Connects the mouse and keyboard handlers, subscribes to the shared
        ``'spy_classes_modified'`` event and, unless
        `settings.imshow_enable_rectangle_selector` is `False`, creates an
        initially inactive `RectangleSelector` for interactive labeling.
        If the selector cannot be created, `self.selector` is set to `None`
        and a warning is issued.
        '''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: adopt the classes of the view that
        # emitted the event if this view has none, then redraw.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()
        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            common = dict(button=1, useblit=True, spancoords='data')
            try:
                # matplotlib >= 3.5 takes `props`; `rectprops` and `drawtype`
                # were deprecated in 3.5 and removed in 3.7.
                self.selector = RectangleSelector(self.axes,
                                                  self._select_rectangle,
                                                  props=self.selector_rectprops,
                                                  **common)
            except TypeError:
                # Older matplotlib only understands the legacy keywords.
                self.selector = RectangleSelector(self.axes,
                                                  self._select_rectangle,
                                                  drawtype='box',
                                                  rectprops=self.selector_rectprops,
                                                  **common)
            self.selector.set_active(False)
        except Exception:
            # Best effort: labeling is optional, so degrade with a warning
            # instead of failing `show`.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(msg)

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`.

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`.
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away.
            self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
          (c2 < 0) or (c1 >= self._image_shape[1]):
          self.selection = None
          return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()
    
    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data.'''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm
        from spectral import get_rgb

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        kwargs.update({'cmap': cm, 'vmin': 0, 'norm': NoNorm(),
                       'interpolation': self._interpolation})
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)
        #self.class_axes.axes.set_axis_bgcolor('black')

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        if self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes, mask=(self.classes==0))
        else:
            self.class_rgb = np.array(self.classes)
        
    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        if alpha < 0 or alpha > 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view. Defaults to `settings.imshow_zoom_pixel_width`.

        Returns:

        A new ImageView object for the zoomed view. It shares this view's
        data, classes, display options and `callbacks_common` registry, and
        uses 'nearest' interpolation.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        # Pin the extent to full-image pixel coordinates so the zoomed view
        # uses the same (row, col) coordinate system as this view.
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        view.callbacks_common = self.callbacks_common
        # The display mode is applied by show(). Setting it beforehand would
        # draw the classes early, so show() would warn that show_classes was
        # called twice, and it would raise if this view's mode is still None.
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out (`scale` > 1 zooms in).'''
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()


    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.'''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except:
                pass
        return s

    def __str__(self):
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += '  {0:<20}:  {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation == None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += '  {0:<20}:  {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += '  {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += '    {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
# Example #7
# 0
class NDWindow(wx.Frame):
    '''A widow class for displaying N-dimensional data points.'''
    def __init__(self, data, parent, id, *args, **kwargs):
        '''Creates the N-D display frame, its OpenGL canvas and view state.

        Arguments:

            `data` (ndarray): point data whose last axis indexes features.

            `parent`, `id`: passed through to ``wx.Frame``. Extra positional
            arguments are accepted but not used.

        Keyword arguments (all optional; the full dict is kept in
        ``self.kwargs``):

            `size`: initial window size (default DEFAULT_WIN_SIZE).
            `title`: window title (default 'ND Window').
            `classes`: integer class labels with shape ``data.shape[:-1]``.
                Defaults to all zeros.
            `features`: stored as ``self.features`` (default ``range(6)``).
            `labels`: stored as ``self.labels`` (default: feature indices).
        '''
        self.kwargs = kwargs
        self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
        self.title = kwargs.get('title', 'ND Window')

        #
        # Forcing a specific style on the window.
        #   Should this include styles passed?
        style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE
        super(NDWindow, self).__init__(parent, id,
                                       self.title, wx.DefaultPosition,
                                       wx.Size(*self.size), style, self.title)

        self.gl_initialized = False
        attribs = (glcanvas.WX_GL_RGBA, glcanvas.WX_GL_DOUBLEBUFFER,
                   glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
        self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
        self.canvas.context = wx.glcanvas.GLContext(self.canvas)

        self._have_glut = False
        self.clear_color = (0, 0, 0, 0)
        self.show_axes_tf = True
        self.point_size = 1.0
        self._show_unassigned = True
        self._refresh_display_lists = False
        self._click_tolerance = 1
        self._display_commands = []
        self._selection_box = None
        self._rgba_indices = None
        self.mouse_panning = False
        self.win_pos = (100, 100)
        self.fovy = 60.
        self.znear = 0.1
        self.zfar = 10.0
        self.target_pos = [0.0, 0.0, 0.0]
        self.camera_pos_rtp = [7.0, 45.0, 30.0]
        self.up = [0.0, 0.0, 1.0]

        self.quadrant_mode = None
        self.mouse_handler = MouseHandler(self)

        # Set the event handlers.
        self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        self.Bind(wx.EVT_SIZE, self.on_resize)
        self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
        self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
        self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
        self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
        self.canvas.Bind(wx.EVT_CHAR, self.on_char)
        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
        self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

        self.data = data
        classes = kwargs.get('classes', None)
        if classes is None:
            # Default: every point unassigned (class 0). The builtin `int`
            # is the dtype; the `np.int` alias was removed in NumPy 1.24.
            classes = np.zeros(data.shape[:-1], int)
        self.classes = classes
        self.features = kwargs.get('features', list(range(6)))
        self.labels = kwargs.get('labels', list(range(data.shape[-1])))
        self.max_menu_class = int(np.max(self.classes.ravel() + 1))

        from matplotlib.cbook import CallbackRegistry
        self.callbacks = CallbackRegistry()

    def on_event_close(self, event=None):
        '''Handler bound to the canvas's EVT_CLOSE; currently does nothing.'''

    def right_click(self, event):
        '''Pop up the MouseMenu context menu where the right-click happened.'''
        canvas = self.canvas
        canvas.SetCurrent(canvas.context)
        menu = MouseMenu(self)
        canvas.PopupMenu(menu, event.GetPosition())

    def add_display_command(self, cmd):
        '''Adds a command to be called next time `display` is run.

        `cmd` is a zero-argument callable. Queued commands are popped and
        called in FIFO order by the paint handler (`on_paint`).
        '''
        self._display_commands += [cmd]

    def reset_view_geometry(self):
        '''Sets viewing geometry to the default view.

        Points the camera at the origin. It places the camera at spherical
        polar coordinates (r, theta, phi) = (2.5, 45.0, 30.0) relative to
        that target. `set_data` rescales every feature to [0, 1], so this
        keeps the scene in view.
        '''
        self.camera_pos_rtp = [2.5, 45.0, 30.0]
        self.target_pos = np.zeros(3)

    def set_data(self, data, **kwargs):
        '''Associates N-D point data with the window.
        ARGUMENTS:
            data (numpy.ndarray):
                An RxCxB array of data points to display.
        KEYWORD ARGUMENTS:
            classes (numpy.ndarray):
                An RxC array of integer class labels (zeros means unassigned).
                If omitted, the window's current class labels are kept.
            features (list):
                Indices of features to display in the octant (see
                NDWindow.set_octant_display_features for description). When
                `data` has fewer than 6 bands, a 6-element feature list is
                truncated to its first 3 entries (single-octant display).

        Each feature is rescaled to span [0, 1]. Point colors are taken from
        the class palette (class 0 drawn white). The RGBA bit layout used to
        encode pixel IDs for mouse picking is computed, and the view geometry
        is reset. Raises Exception if the framebuffer has too few color bits
        to encode every pixel index.
        '''
        import OpenGL.GL as gl
        try:
            from OpenGL.GL import glGetIntegerv
        except ImportError:
            from OpenGL.GL.glget import glGetIntegerv

        classes = kwargs.get('classes', None)
        if classes is not None:
            self.classes = classes
            self.max_menu_class = int(np.max(self.classes.ravel() + 1))

        features = kwargs.get('features', list(range(6)))
        if features is None:
            features = list(range(6))
        # Check the band count of the *new* data. With too few bands for the
        # mirrored 6-feature layout, fall back to 3 features;
        # set_octant_display_features then selects 'single' quadrant mode
        # from the list length.
        if data.shape[-1] < 6 and len(features) == 6:
            features = features[:3]

        # Scale the data set to span an octant

        data2d = np.array(data.reshape((-1, data.shape[-1])))
        mins = np.min(data2d, axis=0)
        maxes = np.max(data2d, axis=0)
        denom = (maxes - mins).astype(float)
        # Constant-valued bands would divide by zero; leave them at 0.
        denom = np.where(denom > 0, denom, 1.0)
        self.data = (data2d - mins) / denom
        self.data.shape = data.shape
        # Cached pick colors depend on the pixel count; force regeneration.
        self._rgba_indices = None

        self.palette = spy_colors.astype(float) / 255.
        self.palette[0] = np.array([1.0, 1.0, 1.0])
        self.colors = self.palette[self.classes.ravel()].reshape(
            self.data.shape[:2] + (3, ))
        self.colors = (self.colors * 255).astype('uint8')
        colors = np.ones((self.colors.shape[:-1]) + (4, ), 'uint8')
        colors[:, :, :-1] = self.colors
        self.colors = colors
        self._refresh_display_lists = True
        self.set_octant_display_features(features)

        # Determine the bit masks to use when using RGBA components for
        # identifying pixel IDs.
        components = [
            gl.GL_RED_BITS, gl.GL_GREEN_BITS, gl.GL_BLUE_BITS,
            gl.GL_ALPHA_BITS
        ]
        self._rgba_bits = [min(8, glGetIntegerv(i)) for i in components]
        self._low_bits = [min(8, 8 - self._rgba_bits[i]) for i in range(4)]
        self._rgba_masks = \
            [(2**self._rgba_bits[i] - 1) << (8 - self._rgba_bits[i])
             for i in range(4)]

        # Determine how many times the scene will need to be rendered in the
        # background to extract the pixel's row/col index.

        N = self.data.shape[0] * self.data.shape[1]
        if N > 2**sum(self._rgba_bits):
            raise Exception(
                'Insufficient color bits (%d) for N-D window display' %
                sum(self._rgba_bits))
        self.reset_view_geometry()

    def set_octant_display_features(self, features):
        '''Specifies features to be displayed in each 3-D coordinate octant.
        `features` can be any of the following:
        None:
            Equivalent to list(range(6)).
        A length-3 list of integer feature IDs:
            In this case, the data points will be displayed in the positive
            x,y,z octant using features associated with the 3 integers.
        A length-6 list of integer feature IDs:
            In this case, each integer specifies a single feature index to be
            associated with the coordinate semi-axes x, y, z, -x, -y, and -z
            (in that order).  Each octant will display data points using the
            features associated with the 3 semi-axes for that octant.
        A length-8 list of length-3 lists of integers:
            In this case, each length-3 list specifies the features to be
            displayed in a single octant (the same semi-axis can be associated
            with different features in different octants).  Octants are ordered
            starting with the positive x,y,z octant and proceed counterclockwise
            around the z-axis, then proceed similarly around the negative half
            of the z-axis.  An octant triplet can be specified as None instead
            of a list, in which case nothing will be rendered in that octant.

        Also sets the quadrant mode ('single', 'mirrored' or 'independent'),
        centers the target on (0.5, 0.5, 0.5) for 'single' and on the origin
        otherwise, and flags the display lists for rebuilding.
        '''
        if features is None:
            features = list(range(6))
        layout = {3: 'single', 6: 'mirrored'}.get(len(features), 'independent')
        if layout == 'single':
            self.octant_features = [features] + [None] * 7
        elif layout == 'mirrored':
            self.octant_features = create_mirrored_octants(features)
        else:
            self.octant_features = features
        center = [0.5] * 3 if layout == 'single' else [0.0] * 3
        self.target_pos = np.array(center)
        if self.quadrant_mode != layout:
            print('Setting quadrant display mode to %s.' % layout)
            self.quadrant_mode = layout
        self._refresh_display_lists = True

    def create_display_lists(self, npass=-1, **kwargs):
        '''Creates or updates the display lists for image data.
        ARGUMENTS:
            `npass` (int):
                When defaulted to -1, the normal image data display lists are
                created.  When >=0, `npass` represents the rendering pass for
                identifying image pixels in the scene by their unique colors.
        KEYWORD ARGS:
            `indices` (list of ints):
                 An optional list of N-D image pixels to display.

        Only the sign of `npass` is used. A negative value uses the class
        colors; a non-negative value uses per-pixel ID colors, cached in
        `_rgba_indices`.

        When `indices` is omitted, all pixels are drawn (or only labeled,
        non-zero-class pixels if `_show_unassigned` is False), and that set
        is saved in `_display_indices`.

        One display list per octant is compiled at IDs gllist_id + 1 through
        gllist_id + 8. Octants whose feature triplet is None get an empty
        list. The axes list is rebuilt via `create_axes_list`, and
        `_refresh_display_lists` is cleared.
        '''
        import OpenGL.GL as gl
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)

        # NOTE(review): point size and color pointer are set again below
        # before any list is compiled; these two calls look redundant.
        gl.glPointSize(self.point_size)
        gl.glColorPointerub(self.colors)

        (R, C, B) = self.data.shape

        indices = kwargs.get('indices', None)
        if indices is None:
            indices = np.arange(R * C)
            if not self._show_unassigned:
                # Class 0 means unassigned; hide those pixels.
                indices = indices[self.classes.ravel() != 0]
            self._display_indices = indices

        # RGB pixel indices for selecting pixels with the mouse
        gl.glPointSize(self.point_size)
        if npass < 0:
            # Colors are associated with image pixel classes.
            gl.glColorPointerub(self.colors)
        else:
            if self._rgba_indices is None:
                # Generate unique colors that correspond to each pixel's ID
                # so that the color can be used to identify the pixel.
                # Channel i carries the next _rgba_bits[i] bits of the flat
                # pixel index, placed in the high bits of that byte (see the
                # masks computed in set_data). get_points_in_selection_box
                # reverses this encoding.
                color_indices = np.arange(R * C)
                rgba = np.zeros((len(color_indices), 4), 'uint8')
                for i in range(4):
                    shift = sum(self._rgba_bits[0:i]) - self._low_bits[i]
                    if shift > 0:
                        rgba[:, i] = (
                            color_indices >> shift) & self._rgba_masks[i]
                    else:
                        rgba[:, i] = (color_indices << self._low_bits[i]) \
                            & self._rgba_masks[i]
                self._rgba_indices = rgba
            gl.glColorPointerub(self._rgba_indices)

        # Generate a display list for each octant of the 3-D window.

        for (i, octant) in enumerate(self.octant_features):
            if octant is not None:
                # Use the octant's 3 features as x, y, z. octant_coeffs[i] is
                # defined elsewhere; presumably it holds per-axis signs that
                # place the points in octant i (confirm).
                data = np.take(self.data, octant, axis=2).reshape((-1, 3))
                data *= octant_coeffs[i]
                gl.glVertexPointerf(data)
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glDrawElementsui(gl.GL_POINTS, indices)
                gl.glEndList()
            else:
                # Create an empty draw list
                gl.glNewList(self.gllist_id + i + 1, gl.GL_COMPILE)
                gl.glEndList()

        self.create_axes_list()
        self._refresh_display_lists = False

    def randomize_features(self):
        '''Randomizes data features displayed using current display mode.

        Draws random feature indices with `random_subset` in the shape the
        current quadrant mode needs:

        * 3 indices for "single";
        * 6 indices for "mirrored";
        * eight 3-element lists otherwise.

        The chosen indices are printed and then applied with
        `set_octant_display_features`.
        '''
        feature_ids = list(range(self.data.shape[2]))
        mode = self.quadrant_mode
        if mode in ('single', 'mirrored'):
            count = 3 if mode == 'single' else 6
            features = random_subset(feature_ids, count)
        else:
            features = [random_subset(feature_ids, 3) for _ in range(8)]
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)

    def set_features(self, features, mode='single'):
        '''Validates `features` for `mode`, applies them and redraws.

        Arguments:

            `features`: 3 feature indices for mode "single", 6 for
                "mirrored", or 8 3-element lists for "independent".

            `mode` (str): one of "single", "mirrored", "independent". It is
                only used for validation; set_octant_display_features infers
                the layout from len(features).

        Raises Exception if `mode` is unrecognized or `features` has the
        wrong length for it.
        '''
        if mode == 'single':
            if len(features) != 3:
                raise Exception(
                    'Expected 3 feature indices for "single" mode.')
        elif mode == 'mirrored':
            if len(features) != 6:
                raise Exception(
                    'Expected 6 feature indices for "mirrored" mode.')
        elif mode == 'independent':
            if len(features) != 8:
                raise Exception('Expected 8 3-tuples of feature indices for '
                                '"independent" mode.')
        else:
            raise Exception('Unrecognized feature mode: %s.' % str(mode))
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)
        self.Refresh()

    def draw_box(self, x0, y0, x1, y1):
        '''Draws a selection box in the 3-D window.
        Coordinates are with respect to the lower left corner of the window.

        Switches the projection to an orthographic window-pixel mapping and
        draws a dashed (stippled) white outline. It then calls `resize`,
        presumably to restore the scene projection.
        '''
        import OpenGL.GL as gl
        width = self.size[0]
        height = self.size[1]
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glOrtho(0.0, width, 0.0, height, -0.01, 10.0)

        gl.glLineStipple(1, 0xF00F)
        gl.glEnable(gl.GL_LINE_STIPPLE)
        gl.glLineWidth(1.0)
        gl.glColor3f(1.0, 1.0, 1.0)
        corners = ((x0, y0), (x1, y0), (x1, y1), (x0, y1))
        gl.glBegin(gl.GL_LINE_LOOP)
        for (vx, vy) in corners:
            gl.glVertex3f(vx, vy, 0.0)
        gl.glEnd()
        gl.glDisable(gl.GL_LINE_STIPPLE)
        gl.glFlush()

        self.resize(*self.size)

    def on_paint(self, event):
        '''Renders the entire scene.

        Handler bound to the canvas's EVT_PAINT. It proceeds as follows:

        1. Make the GL context current.
        2. On the first paint only, run one-time GL setup.
        3. Run queued display commands (see `add_display_command`).
        4. Rebuild the display lists if they are flagged as stale.
        5. Draw the axes (if `show_axes_tf`) and the data set from the
           current camera position.
        6. Draw the selection box, if one is set.
        7. Swap buffers and let the event propagate.
        '''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        self.canvas.SetCurrent(self.canvas.context)
        if not self.gl_initialized:
            # One-time GL setup, deferred to the first paint event.
            self.initgl()
            self.gl_initialized = True
            self.print_help()
            self.resize(*self.size)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # Run queued commands in FIFO order (e.g. the reassign_selection call
        # queued by post_reassign_selection).
        while len(self._display_commands) > 0:
            self._display_commands.pop(0)()

        if self._refresh_display_lists:
            self.create_display_lists()

        gl.glPushMatrix()
        # camera_pos_rtp is relative to target position. To get the absolute
        # camera position, we need to add the target position.
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(*(list(camera_pos_xyz) + list(self.target_pos) +
                        self.up))

        # Display list `gllist_id` holds the coordinate axes.
        if self.show_axes_tf:
            gl.glCallList(self.gllist_id)

        self.draw_data_set()

        gl.glPopMatrix()
        gl.glFlush()

        # The selection box is drawn over the scene in window coordinates.
        if self._selection_box is not None:
            self.draw_box(*self._selection_box)

        # NOTE(review): SwapBuffers is called on the frame, not the canvas;
        # presumably NDWindow defines a forwarding method elsewhere -- confirm.
        self.SwapBuffers()
        event.Skip()

    def post_reassign_selection(self, new_class):
        '''Reassigns pixels in selection box during the next rendering loop.
        ARGUMENT:
            `new_class` (int):
                The class to which the pixels in the box will be assigned.

        Queues `reassign_selection(new_class)` to run on the next paint and
        requests a canvas repaint. If no selection box exists, it prints a
        usage hint instead. Always returns 0.
        '''
        if self._selection_box is None:
            print('Bounding box is not selected. Hold SHIFT and click & ' +
                  'drag with the left\nmouse button to select a region.')
            return 0

        def _reassign():
            return self.reassign_selection(new_class)

        self.add_display_command(_reassign)
        self.canvas.Refresh()
        return 0

    def reassign_selection(self, new_class):
        '''Reassigns pixels in the selection box to the specified class.

        This method should only be called from the paint handler
        (`on_paint`), which makes the GL context current and then runs the
        command queued by `post_reassign_selection`.

        Each pixel in the 3D display is identified by its unique color and
        then reassigned. Pixels can block others in the z-buffer, so the
        method works iteratively:

        1. Remove already-reassigned pixels from the display list.
        2. Reassign the pixels now visible in the selection box.
        3. Repeat until no more pixels are found in the box.

        ARGUMENT:
            `new_class` (int): the class label; also indexes `self.palette`.

        Returns the total number of points whose class changed. When it is
        nonzero, a 'spy_classes_modified' event is processed through
        `self.callbacks`. `max_menu_class` is incremented if `new_class` was
        equal to it. The selection box is cleared.
        '''
        # The RGBA color of the new class does not change between passes.
        # Compute it first so an invalid `new_class` fails before any labels
        # are modified.
        new_color = np.zeros(4, 'uint8')
        new_color[:3] = (np.array(self.palette[new_class]) *
                         255).astype('uint8')

        nreassigned_tot = 0
        print('Reassigning points', end=' ')
        while True:
            # Only consider displayed points that are not yet in new_class.
            indices = np.array(self._display_indices)
            classes = np.array(self.classes.ravel()[indices])
            indices = indices[np.where(classes != new_class)]
            # IDs of those points visible inside the selection box.
            ids = self.get_points_in_selection_box(indices=indices)
            cr = self.classes.ravel()
            nreassigned = np.sum(cr[ids] != new_class)
            nreassigned_tot += nreassigned
            cr[ids] = new_class
            self.colors.reshape((-1, 4))[ids] = new_color
            self.create_display_lists()
            if len(ids) == 0:
                break
            print('.', end=' ')
        print('\n%d points were reasssigned to class %d.' \
              % (nreassigned_tot, new_class))
        self._selection_box = None
        if nreassigned_tot > 0 and new_class == self.max_menu_class:
            self.max_menu_class += 1

        if nreassigned_tot > 0:
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = nreassigned_tot
            self.callbacks.process('spy_classes_modified', event)

        return nreassigned_tot

    def get_points_in_selection_box(self, **kwargs):
        '''Returns pixel IDs of all points in the current selection box.
        KEYWORD ARGS:
            `indices` (ndarray of ints):
                An alternate set of N-D image pixels to display.
            `point_size` (number):
                Point size used for the ID-rendering pass (default 1).

        Pixels are identified by performing a background rendering loop wherein
        each pixel is rendered with a unique color. Then, glReadPixels is used
        to read colors of pixels in the current selection box.

        Returns a 1-D int array of flat pixel IDs. Zero IDs are discarded
        (`ids > 0`), so pixel 0 is never reported. `self.point_size` is
        restored even if rendering or reading fails. The display lists are
        flagged for rebuilding, because the ID pass replaced them.
        '''
        import OpenGL.GL as gl
        indices = kwargs.get('indices', None)
        point_size_temp = self.point_size
        self.point_size = kwargs.get('point_size', 1)
        try:
            (x0, y0, x1, y1) = self._selection_box
            xsize = x1 - x0 + 1
            ysize = y1 - y0 + 1
            ids = np.zeros(xsize * ysize, int)

            self.create_display_lists(0, indices=indices)
            self.render_rgb_indexed_colors()
            # NOTE(review): glReadPixels honors GL_PACK_ALIGNMENT, not UNPACK;
            # harmless here because RGBA rows are always 4-byte aligned.
            gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
            pixels = gl.glReadPixelsub(x0, y0, xsize, ysize, gl.GL_RGBA)
            pixels = np.frombuffer(pixels, dtype=np.uint8).reshape(
                (ysize, xsize, 4))
            # Reassemble each pixel ID from the bits packed into its RGBA
            # channels (inverse of the encoding in create_display_lists).
            for i in range(4):
                component = pixels[:, :, i].reshape((xsize * ysize,)) \
                    & self._rgba_masks[i]
                shift = (sum(self._rgba_bits[0:i]) - self._low_bits[i])
                if shift > 0:
                    ids += component.astype(int) << shift
                else:
                    ids += component.astype(int) >> (-shift)

            points = ids[ids > 0]
        finally:
            self.point_size = point_size_temp
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
        self._refresh_display_lists = True

        return points

    def get_pixel_info(self, x, y, **kwargs):
        '''Prints the image row/col and class of pixels at a raster position.

        ARGUMENTS:
            `x`, `y` (int):
                Window coordinates, measured from the lower left corner.

        Extra keyword arguments are accepted and ignored. Returns None.
        '''
        # Use a one-pixel selection box at the requested location.
        self._selection_box = (x, y, x, y)
        found = self.get_points_in_selection_box(point_size=self.point_size)
        for pixel_id in (p for p in found if p > 0):
            row_col = self.index_to_image_row_col(pixel_id)
            print('Pixel %d %s has class %s.'
                  % (pixel_id, row_col, self.classes[row_col]))

    def render_rgb_indexed_colors(self, **kwargs):
        '''Draws the data set with its current colors for mouse picking.

        The modelview matrix is reset and the color/depth buffers cleared,
        then the data set is drawn from the current camera position. Buffers
        are not swapped, so the colors can be read back to identify pixels.
        Keyword arguments are accepted and ignored.
        '''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # The camera position is stored relative to the target, so offset it
        # by the target position to obtain absolute coordinates.
        eye = np.array(rtp_to_xyz(*self.camera_pos_rtp)) + self.target_pos
        look_at_args = list(eye) + list(self.target_pos) + self.up

        gl.glPushMatrix()
        glu.gluLookAt(*look_at_args)
        self.draw_data_set()
        gl.glPopMatrix()
        gl.glFlush()

    def index_to_image_row_col(self, index):
        '''Maps a flattened (row-major) pixel index to an image (row, col).'''
        ncols = self.data.shape[1]
        return divmod(index, ncols)

    def draw_data_set(self):
        '''Draws the N-D data set by replaying display lists base+1 .. base+8.'''
        import OpenGL.GL as gl
        first = self.gllist_id + 1
        for list_id in range(first, first + 8):
            gl.glCallList(list_id)

    def create_axes_list(self):
        '''Compiles the axes display list (list ID `self.gllist_id`).

        The positive x, y and z unit semi-axes are drawn in red, green and
        blue, and the negative semi-axes in white. If GLUT is available,
        labels are added. In 'independent' quadrant mode the labels are
        'x'/'y'/'z'. Otherwise each labeled semi-axis gets the name
        (`self.labels`) of its displayed feature: all six semi-axes in
        'mirrored' mode, and the positive three in any other mode.

        Labeling is best-effort. Any failure in GLUT text rendering leaves
        the axes unlabeled.
        '''
        import OpenGL.GL as gl
        gl.glNewList(self.gllist_id, gl.GL_COMPILE)
        gl.glBegin(gl.GL_LINES)
        # Positive semi-axes, color coded x=red, y=green, z=blue.
        gl.glColor3f(1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(1.0, 0.0, 0.0)
        gl.glColor3f(0.0, 1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 1.0, 0.0)
        gl.glColor3f(0.0, 0.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 1.0)

        # Negative semi-axes in white.
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(-1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, -1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, -1.0)
        gl.glEnd()

        if self._have_glut:
            try:
                import OpenGL.GLUT as glut
                if bool(glut.glutBitmapString):
                    # Helpers are defined after `glut` is imported so the
                    # dependency is explicit.
                    def label_axis(x, y, z, label):
                        gl.glRasterPos3f(x, y, z)
                        glut.glutBitmapString(glut.GLUT_BITMAP_HELVETICA_18,
                                              str(label))

                    def label_axis_for_feature(x, y, z, feature_ind):
                        (octant, axis) = feature_ind
                        feature = self.octant_features[octant][axis]
                        label_axis(x, y, z, self.labels[feature])

                    if self.quadrant_mode == 'independent':
                        label_axis(1.05, 0.0, 0.0, 'x')
                        label_axis(0.0, 1.05, 0.0, 'y')
                        label_axis(0.0, 0.0, 1.05, 'z')
                    elif self.quadrant_mode == 'mirrored':
                        label_axis_for_feature(1.05, 0.0, 0.0, (0, 0))
                        label_axis_for_feature(0.0, 1.05, 0.0, (0, 1))
                        label_axis_for_feature(0.0, 0.0, 1.05, (0, 2))
                        label_axis_for_feature(-1.05, 0.0, 0.0, (6, 0))
                        label_axis_for_feature(0.0, -1.05, 0.0, (6, 1))
                        label_axis_for_feature(0.0, 0.0, -1.05, (6, 2))
                    else:
                        label_axis_for_feature(1.05, 0.0, 0.0, (0, 0))
                        label_axis_for_feature(0.0, 1.05, 0.0, (0, 1))
                        label_axis_for_feature(0.0, 0.0, 1.05, (0, 2))
            except Exception:
                # Labels are decoration only. Do not let GLUT problems break
                # rendering, but do not swallow SystemExit/KeyboardInterrupt.
                pass
        gl.glEndList()

    def GetGLExtents(self):
        """Placeholder for querying the OpenGL canvas extents.

        Not implemented: always returns None.
        """
        return None

    def SwapBuffers(self):
        """Swaps the OpenGL front/back buffers via the GL canvas."""
        canvas = self.canvas
        canvas.SwapBuffers()

    def on_erase_background(self, event):
        """Handles the erase-background event by ignoring it.

        Skipping the default erase avoids flashing on MS Windows.
        """
        return None

    def initgl(self):
        '''App-specific OpenGL initialization, run once a GL context exists.

        Allocates 9 consecutive display-list IDs (axes plus data lists),
        enables vertex/color client arrays, disables lighting, texturing,
        fog and color material, enables depth testing with flat shading, and
        loads the current data via `set_data`. Finally it tries to initialize
        GLUT. GLUT is optional; `self._have_glut` records whether it worked.
        '''
        import OpenGL.GL as gl
        self.gllist_id = gl.glGenLists(9)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glDisable(gl.GL_LIGHTING)
        gl.glDisable(gl.GL_TEXTURE_2D)
        gl.glDisable(gl.GL_FOG)
        gl.glDisable(gl.GL_COLOR_MATERIAL)
        gl.glEnable(gl.GL_DEPTH_TEST)
        gl.glShadeModel(gl.GL_FLAT)
        self.set_data(self.data, classes=self.classes, features=self.features)

        try:
            import OpenGL.GLUT as glut
            glut.glutInit()
            self._have_glut = True
        except Exception:
            # GLUT is only used here for axis labels; keep running without
            # it when it is missing or fails to initialize.
            pass

    def on_resize(self, event):
        '''Process the wx resize event.

        When a GL context is usable, makes it current, shows the frame,
        resizes the GL viewport/projection to the size reported by `event`,
        and requests a repaint via `self.canvas.Refresh(False)`.
        `event.Skip()` is always called so default wx handling still runs.
        '''

        # For wx versions 2.9.x, GLCanvas.GetContext() always returns None,
        # whereas 2.8.x will return the context so test for both versions.

        if wx.VERSION >= (2, 9) or self.canvas.GetContext():
            self.canvas.SetCurrent(self.canvas.context)
            # Make sure the frame is shown before calling SetCurrent.
            # NOTE(review): Show() below actually runs *after* SetCurrent()
            # above, so this comment and the call order disagree -- confirm
            # which one is intended.
            self.Show()
            size = event.GetSize()
            self.resize(size.width, size.height)
            self.canvas.Refresh(False)
        event.Skip()

    def resize(self, width, height):
        """Reshape the OpenGL viewport and projection for a window size.

        Stores the size in `self.size`, sets the viewport to the whole
        window, loads a perspective projection (`self.fovy`, `self.znear`,
        `self.zfar`) and resets the modelview matrix.
        """
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        self.size = (width, height)
        gl.glViewport(0, 0, width, height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        # A minimized window can report a height of 0; avoid dividing by
        # zero when computing the aspect ratio.
        aspect = float(width) / max(height, 1)
        glu.gluPerspective(self.fovy, aspect, self.znear, self.zfar)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()

    def on_char(self, event):
        '''Handles a keyboard character event by dispatching on the key.

        Key meanings are listed by `print_help`. After any key, including
        unrecognized ones, the canvas is asked to repaint. The one exception
        is a rejected 'd', when the data have fewer than 6 features.
        '''
        ch = chr(event.GetKeyCode())

        if ch == 'd':
            nfeatures = self.data.shape[2]
            if nfeatures < 6:
                print('Only single-quadrant mode is supported for %d features.'
                      % nfeatures)
                return
            # Cycle single -> mirrored -> independent -> single; any other
            # current mode restarts the cycle at 'single'.
            following = {'single': 'mirrored', 'mirrored': 'independent'}
            self.quadrant_mode = following.get(self.quadrant_mode, 'single')
            print('Setting quadrant display mode to %s.' % self.quadrant_mode)
            self.randomize_features()
        elif ch == 'a':
            self.show_axes_tf = not self.show_axes_tf
        elif ch == 'c':
            self.view_class_image()
        elif ch == 'f':
            self.randomize_features()
        elif ch == 'h':
            self.print_help()
        elif ch == 'm':
            self.mouse_panning = not self.mouse_panning
        elif ch in ('p', 'P'):
            if ch == 'p':
                self.point_size += 1
            else:
                self.point_size = max(self.point_size - 1, 1.0)
            self._refresh_display_lists = True
        elif ch == 'q':
            self.on_event_close()
            self.Close(True)
        elif ch == 'r':
            self.reset_view_geometry()
        elif ch == 'u':
            self._show_unassigned = not self._show_unassigned
            print('SHOW UNASSIGNED =', self._show_unassigned)
            self._refresh_display_lists = True

        self.canvas.Refresh()

    def update_window_title(self):
        '''Sets the window title to 'SPy N-D Data Set'.

        This window is a wx frame rather than a GLUT window, so the title is
        set with wx's `SetTitle`. The name `glutSetWindowTitle` was never
        imported here and always raised NameError.
        '''
        s = 'SPy N-D Data Set'
        self.SetTitle(s)

    def get_proxy(self):
        '''Builds an `NDWindowProxy` that gives access to this window's data.'''
        proxy = NDWindowProxy(self)
        return proxy

    def view_class_image(self, *args, **kwargs):
        '''Opens a dynamic raster view of this window's class labels.

        `args` and `kwargs` are forwarded to the `ImageView` constructor,
        together with `classes=self.classes`. The new view shares this
        window's callback registry (as `callbacks_common`), is shown, and is
        returned.
        '''
        image_view = ImageView(*args, classes=self.classes, **kwargs)
        image_view.callbacks_common = self.callbacks
        image_view.show()
        return image_view

    def print_help(self):
        '''Prints a list of accepted keyboard/mouse inputs.'''
        print('''Mouse functions:
---------------
Left-click & drag       -->     Rotate viewing geometry (or pan)
CTRL+Left-click & drag  -->     Zoom viewing geometry
CTRL+SHIFT+Left-click   -->     Print image row/col and class of selected pixel
SHIFT+Left-click & drag -->     Define selection box in the window
Right-click             -->     Open GLUT menu for pixel reassignment

Keyboard functions:
-------------------
a       -->     Toggle axis display
c       -->     View dynamic raster image of class values
d       -->     Cycle display mode between single-quadrant, mirrored octants,
                and independent octants (display will not change until features
                are randomzed again)
f       -->     Randomize features displayed
h       -->     Print this help message
m       -->     Toggle mouse function between rotate/zoom and pan modes
p/P     -->     Increase/Decrease the size of displayed points
q       -->     Exit the application
r       -->     Reset viewing geometry
u       -->     Toggle display of unassigned points (points with class == 0)
''')
# Example #8
0
class NDWindow(wx.Frame):
    '''A widow class for displaying N-dimensional data points.'''

    def __init__(self, data, parent, id, *args, **kwargs):
        '''Creates the ND window frame, its OpenGL canvas and view state.

        ARGUMENTS:
            `data` (numpy.ndarray):
                An RxCxB array of N-D data points (B features per pixel).
            `parent`, `id`:
                Passed through to the `wx.Frame` constructor.
        KEYWORD ARGUMENTS:
            `size` (2-tuple): initial window size (default DEFAULT_WIN_SIZE).
            `title` (str): window title (default 'ND Window').
            `classes` (numpy.ndarray): RxC integer class labels; defaults to
                all zeros (unassigned).
            `features` (list): initial octant features (default 0..5).

        Extra positional arguments are accepted but ignored. Display lists
        and GL state are set up later, on the first paint (see `on_paint`).
        '''
        from spectral import settings
        self.kwargs = kwargs
        self.size = kwargs.get('size', DEFAULT_WIN_SIZE)
        self.title = kwargs.get('title', 'ND Window')

        #
        # Forcing a specific style on the window.
        #   Should this include styles passed?
        style = wx.DEFAULT_FRAME_STYLE | wx.NO_FULL_REPAINT_ON_RESIZE
        super(NDWindow, self).__init__(parent, id, self.title,
                                       wx.DefaultPosition,
                                       wx.Size(*self.size),
                                       style,
                                       self.title)

        self.gl_initialized = False
        attribs = (glcanvas.WX_GL_RGBA,
                   glcanvas.WX_GL_DOUBLEBUFFER,
                   glcanvas.WX_GL_DEPTH_SIZE, settings.WX_GL_DEPTH_SIZE)
        self.canvas = glcanvas.GLCanvas(self, attribList=attribs)
        self.canvas.context = wx.glcanvas.GLContext(self.canvas)

        self._have_glut = False
        self.clear_color = (0, 0, 0, 0)
        self.show_axes_tf = True
        self.point_size = 1.0
        self._show_unassigned = True
        self._refresh_display_lists = False
        self._click_tolerance = 1
        self._display_commands = []
        self._selection_box = None
        self._rgba_indices = None
        self.mouse_panning = False
        self.win_pos = (100, 100)
        self.fovy = 60.
        self.znear = 0.1
        self.zfar = 10.0
        self.target_pos = [0.0, 0.0, 0.0]
        self.camera_pos_rtp = [7.0, 45.0, 30.0]
        self.up = [0.0, 0.0, 1.0]

        self.quadrant_mode = None
        self.mouse_handler = MouseHandler(self)

        # Set the event handlers.
        self.canvas.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        self.Bind(wx.EVT_SIZE, self.on_resize)
        self.canvas.Bind(wx.EVT_PAINT, self.on_paint)
        self.canvas.Bind(wx.EVT_LEFT_DOWN, self.mouse_handler.left_down)
        self.canvas.Bind(wx.EVT_LEFT_UP, self.mouse_handler.left_up)
        self.canvas.Bind(wx.EVT_MOTION, self.mouse_handler.motion)
        self.canvas.Bind(wx.EVT_CHAR, self.on_char)
        self.canvas.Bind(wx.EVT_RIGHT_DOWN, self.right_click)
        self.canvas.Bind(wx.EVT_CLOSE, self.on_event_close)

        self.data = data
        # `np.int` no longer exists (removed in NumPy 1.24), and the default
        # passed to `kwargs.get` was built even when `classes` was supplied.
        # Build the all-unassigned default only when it is needed.
        classes = kwargs.get('classes')
        if classes is None:
            classes = np.zeros(data.shape[:-1], int)
        self.classes = classes
        self.features = kwargs.get('features', list(range(6)))
        self.max_menu_class = int(np.max(self.classes.ravel() + 1))

        from matplotlib.cbook import CallbackRegistry
        self.callbacks = CallbackRegistry()


    def on_event_close(self, event=None):
        '''Hook run when the window closes; currently does nothing.

        `event` defaults to None so the hook can be called without an event.
        '''
        return None

    def right_click(self, event):
        '''Shows a `MouseMenu` popup at the position of the right click.'''
        canvas = self.canvas
        canvas.SetCurrent(canvas.context)
        menu = MouseMenu(self)
        canvas.PopupMenu(menu, event.GetPosition())

    def add_display_command(self, cmd):
        '''Queues a zero-argument callable to run at the start of the next paint.'''
        queue = self._display_commands
        queue.append(cmd)

    def reset_view_geometry(self):
        '''Restores the default viewing geometry.

        The view target is placed at the origin, and the camera at spherical
        polar coordinates (2.5, 45.0, 30.0) relative to that target.
        '''
        # Data are rescaled into [0, 1] per feature, so the origin is a
        # reasonable scene center.
        self.target_pos = np.zeros(3)
        self.camera_pos_rtp = [2.5, 45.0, 30.0]

    def set_data(self, data, **kwargs):
        '''Associates N-D point data with the window.

        ARGUMENTS:
            data (numpy.ndarray):
                An RxCxB array of data points to display.
        KEYWORD ARGUMENTS:
            classes (numpy.ndarray):
                An RxC array of integer class labels (zeros means unassigned).
                If omitted (or None), the window's current `classes` are kept.
            features (list):
                Indices of features to display in the octants (see
                NDWindow.set_octant_display_features for description).
                Defaults to the first six features. Only the first three are
                used when the data have fewer than six features.

        Each feature is linearly rescaled to [0, 1], and constant features
        map to 0. Per-pixel RGBA display colors are built from the class
        palette. The RGBA bit layout used to encode pixel IDs for mouse
        picking is read from the current GL context.

        RAISES:
            Exception: if the color buffer does not have enough bits to give
            every pixel a unique ID color.
        '''
        import OpenGL.GL as gl
        try:
            from OpenGL.GL import glGetIntegerv
        except ImportError:
            from OpenGL.GL.glget import glGetIntegerv

        classes = kwargs.get('classes', None)
        if classes is not None:
            self.classes = classes
        features = kwargs.get('features', None)
        if features is None:
            features = list(range(6))
        if data.shape[2] < 6:
            # Too few features for the octant modes. A 3-element list makes
            # set_octant_display_features select 'single' quadrant mode.
            features = features[:3]

        # Scale the data set to span an octant.
        data2d = np.array(data.reshape((-1, data.shape[-1])))
        mins = np.min(data2d, axis=0)
        maxes = np.max(data2d, axis=0)
        denom = (maxes - mins).astype(float)
        denom = np.where(denom > 0, denom, 1.0)
        self.data = ((data2d - mins) / denom).reshape(data.shape)
        # The cached pixel-ID colors depend on the number of pixels; force
        # create_display_lists to regenerate them for the new data.
        self._rgba_indices = None

        # Per-pixel display colors from the class palette (class 0 is white).
        self.palette = spy_colors.astype(float) / 255.
        self.palette[0] = np.array([1.0, 1.0, 1.0])
        rgb = self.palette[self.classes.ravel()].reshape(
            self.data.shape[:2] + (3,))
        rgb = (rgb * 255).astype('uint8')
        rgba = np.ones(rgb.shape[:-1] + (4,), 'uint8')
        rgba[:, :, :-1] = rgb
        self.colors = rgba
        self._refresh_display_lists = True
        self.set_octant_display_features(features)

        # Determine the bit masks to use when using RGBA components for
        # identifying pixel IDs (at most 8 bits of each component are used).
        components = [gl.GL_RED_BITS, gl.GL_GREEN_BITS,
                      gl.GL_BLUE_BITS, gl.GL_ALPHA_BITS]
        self._rgba_bits = [min(8, glGetIntegerv(i)) for i in components]
        self._low_bits = [min(8, 8 - self._rgba_bits[i]) for i in range(4)]
        self._rgba_masks = \
            [(2**self._rgba_bits[i] - 1) << (8 - self._rgba_bits[i])
             for i in range(4)]

        # Every pixel must receive a unique color-encoded ID for picking.
        N = self.data.shape[0] * self.data.shape[1]
        if N > 2**sum(self._rgba_bits):
            raise Exception('Insufficient color bits (%d) for N-D window display'
                            % sum(self._rgba_bits))
        self.reset_view_geometry()

    def set_octant_display_features(self, features):
        '''Chooses which features are plotted in each 3-D coordinate octant.

        `features` may be any of the following:

        * None: treated as ``list(range(6))``.
        * A length-3 list of integer feature IDs: points are drawn only in
          the positive x,y,z octant ('single' mode), and the view target is
          moved to (0.5, 0.5, 0.5).
        * A length-6 list of integer feature IDs: each ID is associated with
          one of the semi-axes x, y, z, -x, -y, -z (in that order), and each
          octant displays the features of its three semi-axes ('mirrored'
          mode).
        * Anything else is taken as a length-8 list of length-3 lists
          ('independent' mode). Each triplet gives the features for one
          octant, so the same semi-axis can show different features in
          different octants. Octants are ordered starting with the positive
          x,y,z octant and proceed counterclockwise around the z-axis, then
          similarly around the negative half of the z-axis. A triplet may be
          None, in which case nothing is rendered in that octant.

        In the last two modes the view target is the origin. A message is
        printed when the quadrant mode changes, and the display lists are
        flagged for rebuilding.
        '''
        if features is None:
            features = list(range(6))
        count = len(features)
        if count == 3:
            mode = 'single'
            octants = [features] + [None] * 7
            center = [0.5, 0.5, 0.5]
        elif count == 6:
            mode = 'mirrored'
            octants = create_mirrored_octants(features)
            center = [0.0, 0.0, 0.0]
        else:
            mode = 'independent'
            octants = features
            center = [0.0, 0.0, 0.0]
        self.octant_features = octants
        self.target_pos = np.array(center)
        if mode != self.quadrant_mode:
            print('Setting quadrant display mode to %s.' % mode)
            self.quadrant_mode = mode
        self._refresh_display_lists = True

    def create_display_lists(self, npass=-1, **kwargs):
        '''Creates or updates the display lists for image data.

        ARGUMENTS:
            `npass` (int):
                When defaulted to -1, the normal image data display lists are
                created (points colored by class). When >= 0, `npass`
                represents the rendering pass for identifying image pixels in
                the scene by their unique colors.
        KEYWORD ARGS:
            `indices` (list of ints):
                An optional list of N-D image pixels to display. When omitted,
                all pixels are drawn (only pixels with a nonzero class if
                unassigned points are hidden), and that set is saved in
                `self._display_indices`.

        One display list is compiled per octant (IDs gllist_id+1 ..
        gllist_id+8; an empty list for unused octants). The axes list is then
        rebuilt and the refresh flag is cleared.
        '''
        import OpenGL.GL as gl
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)

        (R, C, B) = self.data.shape

        indices = kwargs.get('indices', None)
        if indices is None:
            indices = np.arange(R * C)
            if not self._show_unassigned:
                indices = indices[self.classes.ravel() != 0]
            self._display_indices = indices

        gl.glPointSize(self.point_size)
        if npass < 0:
            # Colors are associated with image pixel classes.
            gl.glColorPointerub(self.colors)
        else:
            if self._rgba_indices is None:
                # Generate unique colors that correspond to each pixel's ID
                # so that the color can be used to identify the pixel.
                color_indices = np.arange(R * C)
                rgba = np.zeros((len(color_indices), 4), 'uint8')
                for i in range(4):
                    # Component i carries the ID bits starting at bit
                    # sum(_rgba_bits[:i]) and stores them in the top
                    # _rgba_bits[i] bits of its byte. The shift here must
                    # exactly mirror the decoding in
                    # get_points_in_selection_box, which uses -shift.
                    shift = sum(self._rgba_bits[0:i]) - self._low_bits[i]
                    if shift > 0:
                        rgba[:, i] = (
                            color_indices >> shift) & self._rgba_masks[i]
                    else:
                        rgba[:, i] = (
                            color_indices << (-shift)) & self._rgba_masks[i]
                self._rgba_indices = rgba
            gl.glColorPointerub(self._rgba_indices)

        # Generate a display list for each octant of the 3-D window.
        for (i, octant) in enumerate(self.octant_features):
            list_id = self.gllist_id + i + 1
            if octant is None:
                # Empty list so draw_data_set can call every octant list.
                gl.glNewList(list_id, gl.GL_COMPILE)
                gl.glEndList()
                continue
            data = np.take(self.data, octant, axis=2).reshape((-1, 3))
            # Per-octant coordinate multipliers (presumably axis signs).
            data *= octant_coeffs[i]
            gl.glVertexPointerf(data)
            gl.glNewList(list_id, gl.GL_COMPILE)
            gl.glDrawElementsui(gl.GL_POINTS, indices)
            gl.glEndList()

        self.create_axes_list()
        self._refresh_display_lists = False

    def randomize_features(self):
        '''Displays a new random choice of features for the current mode.

        'single' mode draws 3 features and 'mirrored' draws 6. Any other mode
        draws an independent 3-feature subset for each of the 8 octants. The
        chosen IDs are printed and then applied.
        '''
        from pprint import pprint
        all_ids = list(range(self.data.shape[2]))
        counts = {'single': 3, 'mirrored': 6}
        mode = self.quadrant_mode
        if mode in counts:
            features = random_subset(all_ids, counts[mode])
        else:
            features = [random_subset(all_ids, 3) for _ in range(8)]
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)

    def set_features(self, features, mode='single'):
        '''Validates and displays an explicit set of octant features.

        ARGUMENTS:
            `features` (list):
                3 feature indices for 'single' mode, 6 for 'mirrored', or
                8 per-octant 3-tuples for 'independent'.
            `mode` (str):
                The layout `features` is expected to describe. It is used
                only for validation. set_octant_display_features infers the
                layout from the length of `features`.

        The features are printed, applied, and the window is refreshed.

        RAISES:
            Exception: if `mode` is unrecognized, or if the length of
            `features` does not match `mode`.
        '''
        from pprint import pprint
        if mode == 'single':
            if len(features) != 3:
                raise Exception(
                    'Expected 3 feature indices for "single" mode.')
        elif mode == 'mirrored':
            if len(features) != 6:
                raise Exception(
                    'Expected 6 feature indices for "mirrored" mode.')
        elif mode == 'independent':
            if len(features) != 8:
                # The two literals are concatenated; keep the separating space.
                raise Exception('Expected 8 3-tuples of feature indices for '
                                '"independent" mode.')
        else:
            raise Exception('Unrecognized feature mode: %s.' % str(mode))
        print('New feature IDs:')
        pprint(np.array(features))
        self.set_octant_display_features(features)
        self.Refresh()

    def draw_box(self, x0, y0, x1, y1):
        '''Draws the selection box over the 3-D scene.

        Coordinates are window pixels measured from the lower left corner.
        The box is a white, stippled line loop drawn under a temporary
        orthographic projection. Afterwards `self.resize(*self.size)` is
        called, which presumably restores the normal projection.
        '''
        import OpenGL.GL as gl
        (width, height) = (self.size[0], self.size[1])
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glOrtho(0.0, width, 0.0, height, -0.01, 10.0)

        gl.glLineStipple(1, 0xF00F)
        gl.glEnable(gl.GL_LINE_STIPPLE)
        gl.glLineWidth(1.0)
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glBegin(gl.GL_LINE_LOOP)
        for (vx, vy) in ((x0, y0), (x1, y0), (x1, y1), (x0, y1)):
            gl.glVertex3f(vx, vy, 0.0)
        gl.glEnd()
        gl.glDisable(gl.GL_LINE_STIPPLE)
        gl.glFlush()

        self.resize(*self.size)

    def on_paint(self, event):
        '''Renders the entire scene.

        On the first call, it performs one-time GL setup (`initgl`), prints
        the help text, and sets up the projection for the current size. On
        every call it:

        * runs any queued display commands;
        * rebuilds the display lists if they are flagged stale;
        * draws the axes (if enabled) and the data from the current camera
          position;
        * overlays the selection box if one exists;
        * swaps buffers.
        '''
        # NOTE(review): `time` is imported but not used in this method.
        import time
        import OpenGL.GL as gl
        import OpenGL.GLU as glu

        self.canvas.SetCurrent(self.canvas.context)
        if not self.gl_initialized:
            self.initgl()
            self.gl_initialized = True
            self.print_help()
            self.resize(*self.size)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # Run commands queued with add_display_command (e.g. selection
        # reassignment) now that the GL context is current.
        while len(self._display_commands) > 0:
            self._display_commands.pop(0)()

        if self._refresh_display_lists:
            self.create_display_lists()

        gl.glPushMatrix()
        # camera_pos_rtp is relative to target position. To get the absolute
        # camera position, we need to add the target position.
        camera_pos_xyz = np.array(rtp_to_xyz(*self.camera_pos_rtp)) \
            + self.target_pos
        glu.gluLookAt(
            *(list(camera_pos_xyz) + list(self.target_pos) + self.up))

        if self.show_axes_tf:
            gl.glCallList(self.gllist_id)

        self.draw_data_set()

        gl.glPopMatrix()
        gl.glFlush()

        # The selection box is drawn after the scene, in window coordinates.
        if self._selection_box is not None:
            self.draw_box(*self._selection_box)

        self.SwapBuffers()
        event.Skip()

    def post_reassign_selection(self, new_class):
        '''Schedules reassignment of the selected pixels for the next repaint.

        ARGUMENT:
            `new_class` (int):
                Class label to give every pixel inside the selection box.

        If no selection box exists, a usage hint is printed instead. Always
        returns 0.
        '''
        if self._selection_box is not None:
            self.add_display_command(
                lambda: self.reassign_selection(new_class))
            self.canvas.Refresh()
        else:
            print('Bounding box is not selected. Hold SHIFT and click & '
                  'drag with the left\nmouse button to select a region.')
        return 0

    def reassign_selection(self, new_class):
        '''Reassigns pixels in the selection box to the specified class.

        This method should only be called from the paint loop (it is queued
        by `post_reassign_selection`). Pixels in the 3-D display are
        identified by their unique ID colors and then reassigned. Points can
        hide others in the z-buffer, so the process repeats: reassigned
        points are removed from the drawn set, and the box is read again
        until no more points are found in it.

        ARGUMENT:
            `new_class` (int): class label to assign.

        Returns the total number of pixels whose class changed. When that
        count is positive, a 'spy_classes_modified' event is sent through
        `self.callbacks`.
        '''
        nreassigned_tot = 0
        new_color = np.zeros(4, 'uint8')
        new_color[:3] = (np.array(self.palette[new_class])
                         * 255).astype('uint8')
        print('Reassigning points', end=' ')
        while True:
            indices = np.array(self._display_indices)
            classes = np.array(self.classes.ravel()[indices])
            indices = indices[np.where(classes != new_class)]
            # De-duplicate so a pixel read more than once is counted once.
            ids = np.unique(self.get_points_in_selection_box(indices=indices))
            # Use `.flat` so writes reach self.classes even if it is not
            # C-contiguous. In that case ravel() returns a copy, the update
            # would be lost, and the loop would never terminate.
            nreassigned = np.sum(self.classes.flat[ids] != new_class)
            nreassigned_tot += nreassigned
            self.classes.flat[ids] = new_class
            self.colors.reshape((-1, 4))[ids] = new_color
            self.create_display_lists()
            if len(ids) == 0:
                break
            print('.', end=' ')
        print('\n%d points were reassigned to class %d.'
              % (nreassigned_tot, new_class))
        self._selection_box = None
        if nreassigned_tot > 0 and new_class == self.max_menu_class:
            self.max_menu_class += 1

        if nreassigned_tot > 0:
            from .spypylab import SpyMplEvent
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = nreassigned_tot
            self.callbacks.process('spy_classes_modified', event)

        return nreassigned_tot

    def get_points_in_selection_box(self, **kwargs):
        '''Returns pixel IDs of all points in the current selection box.

        KEYWORD ARGS:
            `indices` (ndarray of ints):
                An alternate set of N-D image pixels to display.

            `point_size` (number, default 1):
                Point size used for the off-screen ID rendering pass. The
                window's own `point_size` is restored afterwards.

        Pixels are identified by performing a background rendering loop wherein
        each pixel is rendered with a unique color. Then, glReadPixels is used
        to read colors of pixels in the current selection box.

        Returns a 1-D integer array of decoded pixel IDs. IDs equal to 0 are
        dropped (they match the cleared background color), so pixel 0 is
        never reported.
        '''
        import OpenGL.GL as gl
        indices = kwargs.get('indices', None)
        point_size_temp = self.point_size
        self.point_size = kwargs.get('point_size', 1)

        # Accept a box whose corners were recorded in either order.
        (x0, y0, x1, y1) = self._selection_box
        (x0, x1) = (min(x0, x1), max(x0, x1))
        (y0, y1) = (min(y0, y1), max(y0, y1))
        xsize = x1 - x0 + 1
        ysize = y1 - y0 + 1
        ids = np.zeros(xsize * ysize, int)

        try:
            self.create_display_lists(0, indices=indices)
            self.render_rgb_indexed_colors()
            # glReadPixels is governed by the PACK (not UNPACK) alignment.
            gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
            pixels = gl.glReadPixelsub(x0, y0, xsize, ysize, gl.GL_RGBA)
            pixels = np.frombuffer(pixels, dtype=np.uint8).reshape(
                (ysize, xsize, 4))
            for i in range(4):
                component = pixels[:, :, i].reshape((xsize * ysize,)) \
                    & self._rgba_masks[i]
                # Undo the per-component shift applied when the ID colors
                # were generated in create_display_lists.
                shift = (sum(self._rgba_bits[0:i]) - self._low_bits[i])
                if shift > 0:
                    ids += component.astype(int) << shift
                else:
                    ids += component.astype(int) >> (-shift)
        finally:
            # Always restore the user's point size and request a rebuild of
            # the normal display lists, even if rendering/reading failed.
            self.point_size = point_size_temp
            self._refresh_display_lists = True

        points = ids[ids > 0]

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        return points

    def get_pixel_info(self, x, y, **kwargs):
        '''Prints the image row/col and class of pixels at a raster position.

        ARGUMENTS:
            `x`, `y` (int):
                Window coordinates, measured from the lower left corner.

        Extra keyword arguments are accepted and ignored. Returns None.
        '''
        # Pick with a one-pixel selection box at the requested location.
        self._selection_box = (x, y, x, y)
        hits = self.get_points_in_selection_box(point_size=self.point_size)
        for pixel_id in hits:
            if not pixel_id > 0:
                continue
            row_col = self.index_to_image_row_col(pixel_id)
            label = self.classes[row_col]
            print('Pixel %d %s has class %s.' % (pixel_id, row_col, label))

    def render_rgb_indexed_colors(self, **kwargs):
        '''Draws the data set with its current colors for mouse picking.

        The modelview matrix is reset and the color/depth buffers cleared,
        then the data set is drawn from the current camera position. Buffers
        are not swapped, so the colors can be read back to identify pixels.
        Keyword arguments are accepted and ignored.
        '''
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        # The camera position is stored relative to the target; offset it by
        # the target position to obtain absolute coordinates.
        eye = np.array(rtp_to_xyz(*self.camera_pos_rtp)) + self.target_pos
        look_at_args = list(eye) + list(self.target_pos) + self.up

        gl.glPushMatrix()
        glu.gluLookAt(*look_at_args)
        self.draw_data_set()
        gl.glPopMatrix()
        gl.glFlush()

    def index_to_image_row_col(self, index):
        '''Converts a flattened (row-major) pixel index to an image (row, col).

        Integer division is required. True division (`/`) yields a float row
        index under Python 3, which cannot be used to index `self.classes`.
        '''
        ncols = self.data.shape[1]
        rowcol = (index // ncols, index % ncols)
        return rowcol

    def draw_data_set(self):
        '''Draws the N-D data set by replaying the eight octant display lists.

        These are list IDs gllist_id+1 through gllist_id+8, as compiled by
        create_display_lists.
        '''
        import OpenGL.GL as gl
        base = self.gllist_id
        for offset in range(1, 9):
            gl.glCallList(base + offset)

    def create_axes_list(self):
        '''Compiles the axes display list (list ID `self.gllist_id`).

        The positive x, y and z unit semi-axes are drawn in red, green and
        blue, and the negative semi-axes in white. If GLUT is available, the
        positive semi-axes are labeled 'x', 'y' and 'z'. Labeling is
        best-effort: GLUT failures leave the axes unlabeled.
        '''
        import OpenGL.GL as gl
        gl.glNewList(self.gllist_id, gl.GL_COMPILE)
        gl.glBegin(gl.GL_LINES)
        # Positive semi-axes, color coded x=red, y=green, z=blue.
        gl.glColor3f(1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(1.0, 0.0, 0.0)
        gl.glColor3f(0.0, 1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 1.0, 0.0)
        gl.glColor3f(0.0, 0.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 1.0)

        # Negative semi-axes in white.
        gl.glColor3f(1.0, 1.0, 1.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(-1.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, -1.0, 0.0)
        gl.glVertex3f(0.0, 0.0, 0.0)
        gl.glVertex3f(0.0, 0.0, -1.0)
        gl.glEnd()

        if self._have_glut:
            try:
                import OpenGL.GLUT as glut
                if bool(glut.glutBitmapCharacter):
                    labels = (((1.05, 0.0, 0.0), 'x'),
                              ((0.0, 1.05, 0.0), 'y'),
                              ((0.0, 0.0, 1.05), 'z'))
                    for (position, char) in labels:
                        gl.glRasterPos3f(*position)
                        glut.glutBitmapCharacter(
                            glut.GLUT_BITMAP_HELVETICA_18, ord(char))
            except Exception:
                # Labels are decoration only. Do not let GLUT problems break
                # rendering, but do not swallow SystemExit/KeyboardInterrupt.
                pass
        gl.glEndList()

    def GetGLExtents(self):
        """Get the extents of the OpenGL canvas.

        Not implemented: this placeholder always returns ``None``.
        """
        return None

    def SwapBuffers(self):
        """Swap the OpenGL buffers by delegating to ``self.canvas``.

        Returns ``None``; the canvas' own return value is discarded.
        """
        canvas = self.canvas
        canvas.SwapBuffers()

    def on_erase_background(self, event):
        """Handle the erase-background event by deliberately ignoring it.

        Skipping the default erase avoids flashing on MS Windows.
        """
        return None

    def initgl(self):
        '''App-specific OpenGL initialization, run once a GL context exists.

        Steps:

        - Reserves 9 consecutive display-list IDs starting at
          ``self.gllist_id``.
        - Enables the vertex and color client arrays.
        - Disables lighting, texturing, fog and color material.
        - Enables depth testing and flat shading.
        - Loads the current data set via `set_data`.
        - Finally tries to initialize GLUT (needed for the axis labels drawn
          by `create_axes_list`) and stores the outcome in
          ``self._have_glut``.
        '''
        import OpenGL.GL as gl
        self.gllist_id = gl.glGenLists(9)
        gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
        gl.glEnableClientState(gl.GL_COLOR_ARRAY)
        gl.glDisable(gl.GL_LIGHTING)
        gl.glDisable(gl.GL_TEXTURE_2D)
        gl.glDisable(gl.GL_FOG)
        gl.glDisable(gl.GL_COLOR_MATERIAL)
        gl.glEnable(gl.GL_DEPTH_TEST)
        gl.glShadeModel(gl.GL_FLAT)
        self.set_data(self.data, classes=self.classes, features=self.features)

        try:
            import OpenGL.GLUT as glut
            glut.glutInit()
        except Exception:
            # GLUT is optional: without it the axis labels are omitted.
            # Record the failure explicitly rather than leaving the flag unset.
            self._have_glut = False
        else:
            self._have_glut = True

    def on_resize(self, event):
        '''Process the resize event.

        When a GL context is usable, this shows the frame, makes the
        canvas' context current, resizes the GL viewport/projection to the
        new size and requests a redraw. The event is always skipped so wx
        continues default processing.
        '''
        # For wx versions 2.9.x, GLCanvas.GetContext() always returns None,
        # whereas 2.8.x will return the context so test for both versions.
        if wx.VERSION >= (2, 9) or self.canvas.GetContext():
            # The frame must be shown before calling SetCurrent: wx only
            # allows a GL context to be made current on a window that is
            # shown on screen.
            self.Show()
            self.canvas.SetCurrent(self.canvas.context)
            size = event.GetSize()
            self.resize(size.width, size.height)
            self.canvas.Refresh(False)
        event.Skip()

    def resize(self, width, height):
        """Reshape the OpenGL viewport based on dimensions of the window.

        Stores ``(width, height)`` in ``self.size``, sets the viewport to the
        whole window, and installs a perspective projection from
        ``self.fovy``, ``self.znear`` and ``self.zfar``. On return the
        modelview matrix is selected and reset to identity.
        """
        import OpenGL.GL as gl
        import OpenGL.GLU as glu
        self.size = (width, height)
        gl.glViewport(0, 0, width, height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        # Guard against a zero-height window (e.g. a collapsed or minimized
        # frame). Without it, the aspect ratio raises ZeroDivisionError.
        aspect = float(width) / max(height, 1)
        glu.gluPerspective(self.fovy, aspect, self.znear, self.zfar)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()

    def on_char(self, event):
        '''Dispatch a typed character to the matching window action.

        See `print_help` for what each key does. Unrecognized keys are
        ignored. The canvas is refreshed afterwards in every case except one:
        'd' (quadrant mode) is rejected when the data has fewer than six
        features.
        '''
        key = chr(event.GetKeyCode())

        if key == 'd':
            # Multi-quadrant modes require at least six features.
            n_features = self.data.shape[2]
            if n_features < 6:
                print('Only single-quadrant mode is supported for %d features.' %
                      n_features)
                return
            # Cycle single -> mirrored -> independent -> single.
            current = self.quadrant_mode
            self.quadrant_mode = ('mirrored' if current == 'single' else
                                  'independent' if current == 'mirrored' else
                                  'single')
            print('Setting quadrant display mode to %s.' % self.quadrant_mode)
            self.randomize_features()
        elif key in ('p', 'P'):
            # Lowercase grows the points; uppercase shrinks them (floor 1.0).
            if key == 'p':
                self.point_size += 1
            else:
                self.point_size = max(self.point_size - 1, 1.0)
            self._refresh_display_lists = True
        elif key == 'u':
            self._show_unassigned = not self._show_unassigned
            print('SHOW UNASSIGNED =', self._show_unassigned)
            self._refresh_display_lists = True
        elif key == 'q':
            self.on_event_close()
            self.Close(True)
        elif key == 'a':
            self.show_axes_tf = not self.show_axes_tf
        elif key == 'm':
            self.mouse_panning = not self.mouse_panning
        elif key == 'c':
            self.view_class_image()
        elif key == 'f':
            self.randomize_features()
        elif key == 'h':
            self.print_help()
        elif key == 'r':
            self.reset_view_geometry()

        self.canvas.Refresh()

    def update_window_title(self):
        '''Set the window title to the fixed string 'SPy N-D Data Set'.

        The previous implementation called ``glutSetWindowTitle``, which is
        neither defined nor imported here. It also does not apply to this
        window, which is managed by wx rather than GLUT, so the frame's own
        ``SetTitle`` is used instead.
        '''
        s = 'SPy N-D Data Set'
        self.SetTitle(s)

    def get_proxy(self):
        '''Create and return a new `NDWindowProxy` wrapping this window.

        The proxy gives access to data from the window.
        '''
        proxy = NDWindowProxy(self)
        return proxy

    def view_class_image(self, *args, **kwargs):
        '''Opens a dynamic raster image of class values.

        The class IDs displayed are those currently associated with the ND
        window (``self.classes``). `args` and `kwargs` are additional
        arguments passed on to the `ImageView` constructor. The new view is
        given this window's ``callbacks`` object as its ``callbacks_common``
        and is shown before being returned.

        Returns the new `ImageView` object.
        '''
        from .spypylab import ImageView
        view = ImageView(*args, classes=self.classes, **kwargs)
        # Share this window's callback object with the raster view.
        view.callbacks_common = self.callbacks
        view.show()
        return view

    def print_help(self):
        '''Prints a list of accepted keyboard/mouse inputs.'''
        import os
        print('''Mouse functions:
---------------
Left-click & drag       -->     Rotate viewing geometry (or pan)
CTRL+Left-click & drag  -->     Zoom viewing geometry
CTRL+SHIFT+Left-click   -->     Print image row/col and class of selected pixel
SHIFT+Left-click & drag -->     Define selection box in the window
Right-click             -->     Open GLUT menu for pixel reassignment

Keyboard functions:
-------------------
a       -->     Toggle axis display
c       -->     View dynamic raster image of class values
d       -->     Cycle display mode between single-quadrant, mirrored octants,
                and independent octants (display will not change until features
                are randomzed again)
f       -->     Randomize features displayed
h       -->     Print this help message
m       -->     Toggle mouse function between rotate/zoom and pan modes
p/P     -->     Increase/Decrease the size of displayed points
q       -->     Exit the application
r       -->     Reset viewing geometry
u       -->     Toggle display of unassigned points (points with class == 0)
''')
# Example #9
# 0
class Cursor:
    """A cursor for selecting artists on a matplotlib figure.

    Each selection is an annotation placed on the picked artist. Selections
    can be moved with the keyboard and removed with the deselect mouse
    button; see the `bindings` argument of the constructor.
    """

    # Class-level registry mapping each managed artist to the list of cursors
    # attached to it. A cursor holds only weak references to its artists, so
    # this registry is what keeps the cursor alive while its artists are (the
    # registry itself holds the artists weakly).
    _keep_alive = WeakKeyDictionary()

    def __init__(self,
                 artists,
                 *,
                 multiple=False,
                 highlight=False,
                 hover=False,
                 bindings=default_bindings):
        """Construct a cursor.

        Parameters
        ----------

        artists : List[Artist]
            A list of artists that can be selected by this cursor.

        multiple : bool
            Whether multiple artists can be "on" at the same time (defaults to
            False).

        highlight : bool
            Whether to also highlight the selected artist.  If so,
            "highlighter" artists will be placed as the first item in the
            :attr:`extras` attribute of the `Selection`.

        hover : bool
            Whether to select artists upon hovering instead of by clicking.

        bindings : dict
            A mapping of button and keybindings to actions.  Valid entries are:

            =================== ===============================================
            'select'            mouse button to select an artist (default: 1)
            'deselect'          mouse button to deselect an artist (default: 3)
            'left'              move to the previous point in the selected
                                path, or to the left in the selected image
                                (default: shift+left)
            'right'             move to the next point in the selected path, or
                                to the right in the selected image
                                (default: shift+right)
            'up'                move up in the selected image
                                (default: shift+up)
            'down'              move down in the selected image
                                (default: shift+down)
            'toggle_visibility' toggle visibility of all cursors (default: d)
            'toggle_enabled'    toggle whether the cursor is active
                                (default: t)
            =================== ===============================================

        Raises
        ------
        ValueError
            If `hover` and `multiple` are both set, if `bindings` contains an
            unknown key, or if two actions are bound to the same input.
            Validation happens before any callback is connected, so a
            rejected cursor leaves no connections or keep-alive references
            behind.
        """
        artists = list(artists)

        # Validate everything first, so that a failure cannot leave
        # half-registered callbacks or keep-alive entries behind.
        if hover and multiple:
            raise ValueError("`hover` and `multiple` are incompatible")
        bindings = {**default_bindings, **bindings}
        if set(bindings) != set(default_bindings):
            raise ValueError("Unknown bindings")
        actually_bound = {k: v for k, v in bindings.items() if v is not None}
        if len(set(actually_bound.values())) != len(actually_bound):
            raise ValueError("Duplicate bindings")
        self._bindings = bindings

        # Be careful with GC: hold the artists weakly, and let the artists
        # keep the cursor alive through the class-level registry.
        self._artists = [weakref.ref(artist) for artist in artists]
        for artist in artists:
            type(self)._keep_alive.setdefault(artist, []).append(self)

        self._multiple = multiple
        self._highlight = highlight

        self._axes = {artist.axes for artist in artists}
        self._enabled = True
        self._selections = []
        self._callbacks = CallbackRegistry()

        connect_pairs = [("key_press_event", self._on_key_press)]
        if hover:
            connect_pairs.append(
                ("motion_notify_event", self._on_select_button_press))
        else:
            connect_pairs.append(
                ("button_press_event", self._on_button_press))
        canvases = {artist.figure.canvas for artist in artists}
        # Each entry, when called, disconnects one (canvas, event) handler.
        self._disconnect_cids = [
            partial(canvas.mpl_disconnect, canvas.mpl_connect(*pair))
            for pair in connect_pairs
            for canvas in canvases]

    @property
    def enabled(self):
        """Whether clicks are registered for picking and unpicking events.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def artists(self):
        """The tuple of selectable artists that are still alive.
        """
        return tuple(filter(None, (ref() for ref in self._artists)))

    @property
    def selections(self):
        """The tuple of current `Selection`\\s.
        """
        return tuple(self._selections)

    def add_selection(self, pi):
        """Create an annotation for a `Selection` and register it.

        Returns a new `Selection`, that has been registered by the `Cursor`,
        with the added annotation set in the :attr:`annotation` field and, if
        applicable, the highlighting artist in the :attr:`extras` field.

        Emits the ``"add"`` event with the new `Selection` as argument.
        """
        # pi: "pick_info", i.e. an incomplete selection.
        ann = pi.artist.axes.annotate(
            _pick_info.get_ann_text(*pi),
            xy=pi.target,
            **default_annotation_kwargs)
        ann.draggable(use_blit=True)
        extras = []
        if self._highlight:
            extras.append(self.add_highlight(pi.artist))
        if not self._multiple:
            # Single-selection mode: drop previous selections first.
            while self._selections:
                self._remove_selection(self._selections[-1])
        sel = pi._replace(annotation=ann, extras=extras)
        self._selections.append(sel)
        self._callbacks.process("add", sel)
        sel.artist.figure.canvas.draw_idle()
        return sel

    def add_highlight(self, artist):
        """Create, add and return a highlighting artist.

        It is up to the caller to register the artist with the proper
        `Selection` in order to ensure cleanup upon deselection.
        """
        hl = copy.copy(artist)
        hl.set(**default_highlight_kwargs)
        artist.axes.add_artist(hl)
        return hl

    def connect(self, event, func=None):
        """Connect a callback to a `Cursor` event; return the callback id.

        Two classes of event can be emitted, both with a `Selection` as single
        argument:

            - ``"add"`` when a `Selection` is added, and
            - ``"remove"`` when a `Selection` is removed.

        The callback registry relies on :mod:`matplotlib`'s implementation; in
        particular, only weak references are kept for bound methods.

        This method can also be used as a decorator::

            @cursor.connect("add")
            def on_add(sel):
                ...

        Raises `ValueError` for an event name other than ``"add"`` or
        ``"remove"``.
        """
        if event not in ["add", "remove"]:
            raise ValueError("Invalid cursor event: {}".format(event))
        if func is None:
            return partial(self.connect, event)
        return self._callbacks.connect(event, func)

    def disconnect(self, cid):
        """Disconnect a previously connected callback id.
        """
        self._callbacks.disconnect(cid)

    def remove(self):
        """Remove all `Selection`\\s and disconnect all callbacks.

        Also drops this cursor from the class-level keep-alive registry, so
        that a removed cursor can be garbage-collected even while its
        artists are still alive.
        """
        for disconnect_cid in self._disconnect_cids:
            disconnect_cid()
        while self._selections:
            self._remove_selection(self._selections[-1])
        for cursors in type(self)._keep_alive.values():
            cursors[:] = [c for c in cursors if c is not self]

    def _on_button_press(self, event):
        """Route a mouse press to the select and/or deselect handler."""
        if event.button == self._bindings["select"]:
            self._on_select_button_press(event)
        if event.button == self._bindings["deselect"]:
            self._on_deselect_button_press(event)

    def _filter_mouse_event(self, event):
        """Return whether this cursor should react to the mouse `event`."""
        # Accept the event iff we are enabled, and either
        #   - no other widget is active, and this is not the second click of a
        #     double click (to prevent double selection), or
        #   - another widget is active, and this is a double click (to bypass
        #     the widget lock).
        return (self.enabled
                and event.canvas.widgetlock.locked() == event.dblclick)

    def _on_select_button_press(self, event):
        """Select the closest pickable artist under `event`, if any."""
        if not self._filter_mouse_event(event):
            return
        # Work around lack of support for twinned axes.
        per_axes_event = {ax: _reassigned_axes_event(event, ax)
                          for ax in self._axes}
        pis = []
        for artist in self.artists:
            if (artist.axes is None  # Removed or figure-level artist.
                    or event.canvas is not artist.figure.canvas
                    or not artist.axes.contains(event)[0]):  # Cropped by axes.
                continue
            pi = _pick_info.compute_pick(artist, per_axes_event[artist.axes])
            if pi:
                pis.append(pi)
        if not pis:
            return
        self.add_selection(min(pis, key=lambda pi: pi.dist))

    def _on_deselect_button_press(self, event):
        """Remove every selection whose annotation contains `event`."""
        if not self._filter_mouse_event(event):
            return
        # Iterate over a snapshot: _remove_selection mutates self._selections,
        # and removing from a list while iterating it skips elements.
        for sel in list(self._selections):
            ann = sel.annotation
            if event.canvas is not ann.figure.canvas:
                continue
            contained, _ = ann.contains(event)
            if contained:
                self._remove_selection(sel)

    def _on_key_press(self, event):
        """Handle the enable/visibility toggles and selection navigation."""
        if event.key == self._bindings["toggle_enabled"]:
            self.enabled = not self.enabled
        elif event.key == self._bindings["toggle_visibility"]:
            for sel in self._selections:
                sel.annotation.set_visible(not sel.annotation.get_visible())
                sel.annotation.figure.canvas.draw_idle()
        if not self._selections:
            return
        # Navigation keys replace the most recent selection by a moved one.
        sel = self._selections[-1]
        for key in ["left", "right", "up", "down"]:
            if event.key == self._bindings[key]:
                self._remove_selection(sel)
                self.add_selection(_pick_info.move(*sel, key=key))
                break

    def _remove_selection(self, sel):
        """Unregister `sel`, remove its artists, and emit ``"remove"``."""
        self._selections.remove(sel)
        # Work around matplotlib/matplotlib#6785.
        try:
            draggable = sel.annotation._draggable
            draggable.disconnect()
            sel.annotation.figure.canvas.mpl_disconnect(draggable._c1)
        except AttributeError:
            pass
        # (end of workaround).
        # <artist>.figure will be unset so we save them first.
        figures = {artist.figure for artist in [sel.annotation, *sel.extras]}
        # An artist that was already removed elsewhere has its figure unset
        # already; there is nothing to redraw for it.
        figures.discard(None)
        # ValueError is raised if the artist has already been removed.
        with suppress(ValueError):
            sel.annotation.remove()
        for artist in sel.extras:
            with suppress(ValueError):
                artist.remove()
        self._callbacks.process("remove", sel)
        for figure in figures:
            figure.canvas.draw_idle()