class Highlighter(object):
    """Highlight scatter points enclosed by rectangles dragged on an axes.

    Every rectangle drawn with the attached ``RectangleSelector`` adds the
    points strictly inside it to ``self.mask``. All selected points are then
    redrawn as large yellow markers on top of the axes.

    ``__init__`` ends with ``plt.show()``, so depending on the backend and
    interactive mode, construction may block until the figure is closed.
    """

    def __init__(self, ax, x, y):
        """
        Parameters
        ----------
        ax : matplotlib Axes
            Axes holding the points to select from.
        x, y : array-like
            Point coordinates, same shape. Converted with ``np.asanyarray``
            so lists/tuples are accepted in addition to ndarrays.
        """
        self.ax = ax
        self.canvas = ax.figure.canvas
        # ``.shape`` and boolean-mask indexing below require arrays.
        # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact.
        self.x, self.y = np.asanyarray(x), np.asanyarray(y)
        # Accumulated selection: True for every point picked so far.
        self.mask = np.zeros(self.x.shape, dtype=bool)

        # Empty scatter that later receives the selected points' offsets.
        self._highlight = ax.scatter([], [], s=200, color='yellow', zorder=10)

        # ``self`` is the selector's callback (see ``__call__``).
        self.selector = RectangleSelector(ax, self, useblit=False)
        if self.selector.active:
            self.selector.update()
        plt.show()

    def __call__(self, event1, event2):
        """Selector callback: add the rectangle's points and redraw highlights."""
        self.mask |= self.inside(event1, event2)
        xy = np.column_stack([self.x[self.mask], self.y[self.mask]])
        self._highlight.set_offsets(xy)
        self.canvas.draw()

    def inside(self, event1, event2):
        """Return a boolean mask of the points strictly inside the rectangle
        spanned by the data coordinates of event1 and event2.

        Points lying exactly on an edge are not included.
        """
        # Note: Could use points_inside_poly, as well
        x0, x1 = sorted([event1.xdata, event2.xdata])
        y0, y1 = sorted([event1.ydata, event2.ydata])
        return (self.x > x0) & (self.x < x1) & (self.y > y0) & (self.y < y1)
class AreaSelector(Frozen):
    """Rectangle selector on ``ax`` that can be toggled from the keyboard.

    Calling the instance with a key event (inside ``ax``) hides and
    deactivates the selector on 'q'/'Q', and shows and re-activates it on
    'a'/'A'.
    """

    def __init__(self, ax, line_select_callback):
        self.ax = ax
        # ``drawtype`` is not passed: 'box' is RectangleSelector's default,
        # and newer matplotlib releases no longer accept the argument.
        self.rs = RectangleSelector(
            ax,
            line_select_callback,
            useblit=False,
            button=[1, 3],  # don't use middle button
            minspanx=0,
            minspany=0,
            spancoords='pixels',
            interactive=True)

    def _set_rectangle_visible(self, visible):
        """Show or hide the drawn rectangle across matplotlib versions."""
        # Older matplotlib exposes the rectangle artist as ``to_draw``; newer
        # releases dropped it, so fall back to the selector-level toggle.
        artist = getattr(self.rs, 'to_draw', None)
        if artist is not None:
            artist.set_visible(visible)
        else:
            self.rs.set_visible(visible)

    def __call__(self, event):
        """Key-event handler: 'q'/'Q' disables the selector, 'a'/'A' enables it."""
        self.rs.update()
        if self.ax != event.inaxes:
            return
        if event.key in ('Q', 'q'):
            self._set_rectangle_visible(False)
            self.rs.set_active(False)
        if event.key in ('A', 'a'):
            self._set_rectangle_visible(True)
            self.rs.set_active(True)
Example #3
0
class FrameView(KeyMapManager, MouseManager):
    """Event-driven viewer stepping through an image stack via keys/scroll."""

    def __init__(self,
                 ims,
                 extent='auto',
                 patch_maker=None,
                 in_range='full',
                 origin='upper',
                 time_unit='',
                 time_loc='upper left',
                 frames=None,
                 times=None,
                 framemaster=None,
                 pixel_cmap='gray',
                 font_dict=None,
                 rectangle=False):
        """
        Simple event driven frameview bound to right and left key presses.

        Parameters
        ----------
        ims : ndimage (frame, rows, columns)
            Image stack; ``ims[i]`` is the image for the i-th frame label.
        extent : ['auto', (left, right, bottom, top), pd.DataFrame indexed by frames]
            The coordinates of the image. 'auto' uses the pixel grid; a
            DataFrame gives a per-frame extent.
        patch_maker : callable
            Accepts a frame argument and returns a list of patches to add
        in_range : ['full', 'auto']
            'full' fixes the color limits to the min/max of the whole stack;
            'auto' leaves vmin/vmax unset so each redrawn frame is scaled
            on its own.
        origin : ['upper', 'lower']
            Passed to ``imshow``, including when frames are redrawn.
        time_unit : 'sec', 'min', ...
            Unit label appended to the time stamp.
        time_loc : ['lower left', ..., 'upper right']
            Set to None to have no time stamp. Any other value shows the
            time stamp as the axes title.
        frames : sequence, optional
            One label per image; defaults to ``range(len(ims))``.
        times : sequence, optional
            One time value per frame; defaults to the frame labels.
        framemaster : sequence, optional
            Master sequence of frame labels to step through. When the master
            frame is missing from ``frames``, the closest earlier frame is
            shown (or the closest later one if no frame is earlier).
        pixel_cmap : str
            colormap for pixel data.
        font_dict : dict, optional
            Stored as ``self.fd``; defaults to ``{'size': 12}``.
        rectangle : bool
            Whether to hook in a rectangle drawing widget for potential cropping functions.

        Raises
        ------
        ValueError
            If ``frames`` and ``ims`` differ in length, or ``extent`` /
            ``in_range`` is not one of the accepted values.

        Example
        -------
        >>> import numpy as np
        >>> ims = np.random.randint(0, 255, (10, 256, 256), dtype=np.uint8)
        >>> fig, ax = plt.subplots()
        >>> fv = FrameView(ims)
        >>> plt.show()
        """
        if font_dict is None:
            # Fresh dict per instance (avoids a shared mutable default).
            font_dict = {'size': 12}
        if frames is None:
            self.frames = deque(range(len(ims)))
        else:
            self.frames = deque(frames)
        # Static copy of the labels in their original order; ``self.frames``
        # itself is rotated as the user steps through the stack.
        self._frames = np.array(self.frames)
        if len(self.frames) != len(ims):
            raise ValueError('frames must provide one label per image '
                             f'(got {len(self.frames)} labels for {len(ims)} images)')
        self.ims = dict(zip(self.frames, ims))
        if framemaster is None:
            self.framemaster = None
            n_steppable = len(self._frames)
        else:
            self.framemaster = deque(framemaster)
            n_steppable = len(self.framemaster)
        # One scroll step jumps ~1/20 of the sequence; clamp to 1 so short
        # sequences (< 20 frames) still respond to the scroll wheel.
        self.stepsize = max(1, n_steppable // 20)
        self.frame = self.frames[0]
        self.times = deque(self.frames if times is None else times)
        self.time_unit = time_unit
        self.fd = font_dict
        # We want a patch generator since it is low on memory (a tad higher in processing)
        # and we want fresh patches to avoid errors from using patches on different axes
        # This is relevant since self.mp4 requires "Agg" backend associated canvas, which we
        # don't want to interfere with pyplot's global state.
        self.patch_maker = patch_maker
        self.ax = plt.gca()
        self.fig = plt.gcf()
        # Axes not created from a gridspec either lack get_subplotspec or
        # return None from it.
        get_subplotspec = getattr(self.ax, 'get_subplotspec', None)
        subplotspec = get_subplotspec() if get_subplotspec is not None else None
        if subplotspec is not None:
            # If using a gridspec, later use of fig.subplots_adjust is ignored
            subplotspec.get_gridspec().update(left=0, right=1, bottom=0, top=1)
        else:
            self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        self.ax.axis('off')

        self.start = 0
        self.stop = len(ims) - 1
        if isinstance(extent, str) and extent == 'auto':
            try:
                shape = ims.shape
            except AttributeError:  # Maybe a PIMS frame object
                shape = (len(ims), *ims.frame_shape)
            self.extent = (0, shape[2], shape[1], 0)
            self.dfextent = None
        elif isinstance(extent, pd.DataFrame):
            self.dfextent = extent
            self.extent = self.dfextent.iloc[0]
        elif isinstance(extent, (tuple, list, np.ndarray)):
            self.dfextent = None
            self.extent = extent
        else:
            raise ValueError("extent must be 'auto', a (left, right, bottom, top) "
                             f"sequence or a DataFrame, got {extent!r}")
        self.pmap = pixel_cmap
        # Remembered so frames redrawn by show_image keep the same orientation.
        self.origin = origin
        # Display ranges
        self.in_range = in_range
        if in_range == 'full':
            self.vmin, self.vmax = int(ims.min()), int(ims.max())
        elif in_range == 'auto':
            self.vmin, self.vmax = None, None
        else:
            raise ValueError(f"in_range must be 'full' or 'auto', got {in_range!r}")
        self.imax = self.ax.imshow(ims[0],
                                   self.pmap,
                                   vmin=self.vmin,
                                   vmax=self.vmax,
                                   interpolation='none',
                                   extent=self.extent,
                                   origin=self.origin)
        self.ax.axis('off')
        self.fig.subplots_adjust(left=0.1, right=0.95, top=0.95, bottom=0.1)
        self.time_loc = time_loc
        self.time_stamp = None
        self.set_time()
        self.add_patches()
        KeyMapManager.__init__(self, self.fig)
        MouseManager.__init__(self, self.fig)
        # Add Key Functions
        self.add_key_callback('right', 'Increment frame', self.handle_right)
        self.add_key_callback('left', 'Decrement frame', self.handle_left)
        # Add Scroll Functions
        self.add_mouse_callback('scroll', self.handle_scroll)
        # Add Rectangle Selector ('box' is the default draw type; the
        # drawtype argument is not accepted by newer matplotlib).
        if rectangle:
            self.rs = RectangleSelector(
                self.ax,
                self.line_select_callback,
                useblit=True,
                button=[1, 3],  # don't use middle button
                minspanx=10,
                minspany=10,
                spancoords='pixels',
                interactive=True)
        else:
            self.rs = None
        # Other
        self.frame_callbacks = []
        self.updaters = []

    def line_select_callback(self, eclick, erelease):
        """RectangleSelector callback; eclick and erelease are the press and
        release events. Prints the selection and keeps only the last two
        lines on the axes."""
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
        print(" The buttons you used were: %s %s" %
              (eclick.button, erelease.button))
        for line in list(self.ax.lines)[:-2]:
            line.remove()

    def add_frame_callback(self, callback):
        """Register a function of the form foo(image, ax) -> ax, called on
        every update with the current frame's image (duplicates ignored)."""
        if callback not in self.frame_callbacks:
            self.frame_callbacks.append(callback)

    def rotate(self, step):
        """Advance the view by ``step`` positions and update ``self.frame``.

        Without a framemaster, frames and times are rotated directly. With a
        framemaster, the master is rotated and the view then shows the
        master's current frame if present, otherwise the nearest one (see
        ``_nearest_frame``).
        """
        if step == 0:
            # Fractional scroll steps can truncate to 0; nothing to move.
            return
        if not self.framemaster:
            self.frames.rotate(step)
            self.times.rotate(step)
            self.frame = self.frames[0]
            return
        self.framemaster.rotate(step)
        fr = self.framemaster[0]
        if fr in self.frames:
            self._rotate_to(fr, 1 if step > 0 else -1)
        else:
            # display most recent frame
            self._rotate_to(self._nearest_frame(fr), 1)

    def _nearest_frame(self, fr):
        """Frame label to display for a master frame ``fr`` absent from frames.

        Returns the closest earlier frame; if every frame comes after ``fr``,
        returns the closest later frame instead.
        """
        df = fr - self._frames
        if np.all(df < 0):
            # All frames are later: the closest has the largest (least
            # negative) difference.
            return self._frames[np.argmax(df)]
        earlier = df > 0
        return self._frames[earlier][np.argmin(df[earlier])]

    def _rotate_to(self, target, direction):
        """Rotate frames/times one position at a time until ``target`` is first."""
        it = 0
        while self.frames[0] != target:
            self.frames.rotate(direction)
            self.times.rotate(direction)
            it += 1
            if it > len(self.frames):
                raise ValueError(f"frame {target!r} not found in self.frames")
        self.frame = self.frames[0]

    def handle_scroll(self, event):
        step = int(event.step * self.stepsize)
        self.rotate(step)
        self.update(event)

    def handle_right(self, event):
        self.rotate(-1)
        self.update(event)

    def handle_left(self, event):
        self.rotate(1)
        self.update(event)

    def update(self, event):
        """Redraw everything for the current frame (``event`` is unused here)."""
        # Ordering here is important. We want to draw the pixels first AND THEN the patches
        # otherwise the pixels will cover that patch.
        # Remove the previous frame's patches (points, etc.) and artists
        # (timestamps, labels, etc.); iterate over copies because removing
        # an artist mutates the axes' containers.
        for p in reversed(list(self.ax.patches)):
            p.remove()
        for a in reversed(list(self.ax.artists)):
            a.remove()
        if self.in_range == 'auto':
            self.show_image()  # new image, color limits recomputed
        else:
            self.set_image()  # swap pixel data, keep fixed color limits
        self.set_extent()  # If the extent changes each frame
        self.set_time()
        self.add_patches()
        for cb in self.frame_callbacks:
            cb(self.ims[self.frame - self.start], self.ax)
        for ud in self.updaters:
            ud()
        self.ax.figure.canvas.draw()
        if self.rs is not None:
            self.rs.update()

    def show_image(self):
        """Replace the displayed image with a fresh imshow of the current frame."""
        image = self.ims[self.frame]
        # Drop the previous image artist; otherwise every frame change would
        # stack another AxesImage on the axes.
        self.imax.remove()
        self.imax = self.ax.imshow(image,
                                   self.pmap,
                                   vmin=self.vmin,
                                   vmax=self.vmax,
                                   interpolation='none',
                                   extent=self.extent,
                                   origin=self.origin)

    def set_image(self):
        """Swap the pixel data of the existing image to the current frame."""
        image = self.ims[self.frame]
        self.imax.set_data(image)

    def set_extent(self):
        """Apply the current frame's extent when a per-frame DataFrame is used."""
        if self.dfextent is not None:
            new_extent = self.dfextent.loc[self.frame].values
            self.extent = new_extent
            self.imax.set_extent(new_extent)

    def add_patches(self):
        """Add the patches returned by ``patch_maker`` for the current frame."""
        if self.patch_maker is not None:
            ps = self.patch_maker(self.frame)
            if ps:
                for p in ps:
                    self.ax.add_patch(p)

    def set_time(self):
        """Show the current time and frame as the axes title (unless time_loc is None)."""
        if self.time_loc is not None:
            if self.times is None:
                label = f'frame {self.frame}'
            else:
                label = f'{self.times[0]:.02f} {self.time_unit}\nframe: {self.frame}'
            self.time_stamp = self.ax.set_title(label)
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    selector_rectprops = dict(facecolor='red',
                              edgecolor='black',
                              alpha=0.5,
                              fill=True)
    selector_lineprops = dict(color='black',
                              linestyle='-',
                              linewidth=2,
                              alpha=0.5)

    def __init__(self,
                 data=None,
                 bands=None,
                 classes=None,
                 source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed, with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display).

        Keyword arguments:

            Any keyword that can be provided to :func:`~spectral.graphics.graphics.get_rgb`
            or :func:`matplotlib.imshow`. An `interpolation` keyword also sets
            the initial value of the `interpolation` property (defaulting to
            `settings.imshow_interpolation`).

        Nothing is drawn here; call :meth:`show` to display the view.
        '''

        import spectral
        from spectral import settings
        # Display/state attributes. These must exist before the setters and
        # set_* calls below, which read them.
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        # `interpolation` is a property: its setter reads
        # `self._interpolation` and `self.is_shown`, both initialized above.
        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        # NOTE(review): assigned after set_classes, so a `colors` value that
        # reaches set_classes through **kwargs is overwritten here.
        self.class_colors = spectral.spy_colors

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        # Module-level helper defined elsewhere; presumably adjusts
        # matplotlib's default callbacks/keymaps — confirm.
        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given. `get_rgb` keywords replace the current RGB options; the
            rest are stored for `imshow` and only take effect if given before
            the view is shown (a UserWarning is issued otherwise).

        Raises:

            `ValueError`: if the first two dimensions of `data` do not match
            the shape of previously set data or classes. The view is left
            unchanged in that case.
        '''
        from .graphics import _get_rgb_kwargs

        # Validate before mutating any state so that a rejected array does
        # not leave the view half-updated.
        if self._image_shape is not None and \
           data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')

        self.data = data
        self.bands = bands

        # Split off the get_rgb keywords; whatever remains goes to imshow.
        rgb_kwargs = {k: kwargs.pop(k) for k in _get_rgb_kwargs if k in kwargs}
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        self.imshow_data_kwargs.update(kwargs)
        if 'interpolation' in self.imshow_data_kwargs:
            # Interpolation is managed through the `interpolation` property.
            self.interpolation = self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Replace the options used to convert the data to RGB.

        Only keywords accepted by :func:`~spectral.graphics.graphics.get_rgb`
        are allowed; the first unknown keyword raises `ValueError`. The
        given options replace (not merge with) the stored ones, and a view
        that is already shown is recomputed and redrawn.
        '''
        from .graphics import _get_rgb_kwargs

        unknown = [name for name in kwargs if name not in _get_rgb_kwargs]
        if unknown:
            raise ValueError('Unexpected keyword: {0}'.format(unknown[0]))
        self.rgb_kwargs = dict(kwargs)
        if not self.is_shown:
            return
        self._update_data_rgb()
        self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.

        Recomputes `self.data_rgb` and `self.data_rgb_meta` from `self.data`,
        `self.bands` and `self.rgb_kwargs`. When the result is monochrome but
        still has a trailing channel axis, only the first channel is kept.
        '''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used (a 2-D array is colored
        # through the colormap, a 3-channel one is not).
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set). Passing
                None clears the classes without redrawing.

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided; they
            only take effect if given before the view is shown (a
            UserWarning is issued otherwise). `get_rgb` keywords are ignored.

        Raises:

            `ValueError`: if the first two dimensions of `classes` do not
            match previously set data. The view is left unchanged then.
        '''
        from .graphics import _get_rgb_kwargs
        if classes is None:
            self.classes = None
            return
        # Validate before mutating any state so that a rejected array does
        # not replace the current classes.
        if self._image_shape is not None and \
           classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')
        self.classes = classes
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        if colors is not None:
            self.class_colors = colors

        # Keywords meant for get_rgb are dropped; the rest go to imshow for
        # the class image.
        kwargs = {k: v for (k, v) in kwargs.items() if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        if 'interpolation' in self.imshow_class_kwargs:
            # Interpolation is managed through the `interpolation` property.
            self.interpolation = self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Attach the object that spectral data for this view is read from.

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                Stored as-is in `self.source`; nothing is read or validated
                here.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Render the view in a matplotlib figure (meant to be called once).

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                When omitted, "data" is chosen if RGB data is set, otherwise
                "classes".

            `fignum` (int):

                Figure number (passed as `num` to `matplotlib.pyplot.figure`).
                If not provided, a new figure will be created.

        A repeated call only issues a UserWarning and returns.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            warnings.warn(UserWarning('ImageView.show should only be called once.'))
            return

        set_mpl_interactive()

        figure_args = {}
        if fignum is not None:
            figure_args['num'] = fignum
        figure_size = settings.imshow_figure_size
        if figure_size is not None:
            figure_args['figsize'] = figure_size
        plt.figure(**figure_args)

        # Create an image artist for each kind of content that is available.
        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is not None:
            self.set_display_mode(mode)
        else:
            self._guess_mode()

        self.axes.format_coord = self.format_coord
        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.

        Sets `self.callbacks`, creates `self.callbacks_common` unless a shared
        registry was already assigned, connects the mouse and keyboard
        handlers, registers a 'spy_classes_modified' handler and, unless
        `settings.imshow_enable_rectangle_selector` is False, creates an
        initially inactive RectangleSelector in `self.selector`. If the
        selector cannot be created, `self.selector` is None and a
        UserWarning is issued.
        '''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: adopt the event's classes if this view
        # has none yet, then redraw.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()

        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            # NOTE(review): `drawtype` and `rectprops` are not accepted by
            # newer matplotlib releases; there this raises and we fall back
            # to the warning below — confirm against the supported versions.
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops = \
                                                  self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # Best effort: the view works without the selector.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`.

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`.
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away.
            self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
          (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Pick a display mode from what is loaded: data first, then classes.'''
        if self.data_rgb is not None:
            chosen = 'data'
        elif self.classes is not None:
            chosen = 'classes'
        else:
            raise Exception('Unable to display image: no data set.')
        self.set_display_mode(chosen)

    def show_data(self):
        '''Draw the RGB data image; warns and returns if already drawn.

        Uses `self.imshow_data_kwargs` with the current interpolation (the
        interpolation entry is written into that dict). If the view already
        has axes, their figure is made current first; otherwise the new
        image's axes become `self.axes`. Raises `Exception` if no RGB data
        is set.
        '''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            warnings.warn(UserWarning('ImageView.show_data should only be called once.'))
            return
        if self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        existing_axes = self.axes
        if existing_axes is not None:
            # Draw into the figure the view already uses.
            plt.figure(existing_axes.figure.number)
        imshow_kwargs = self.imshow_data_kwargs
        imshow_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **imshow_kwargs)
        if existing_axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values.

        Draws the class image with a ListedColormap built from
        `self.class_colors` (scaled from [0, 255] to [0, 1]), creating
        `self.class_axes`. The image alpha is `class_alpha` in 'overlay' mode
        and 1 otherwise. Warns and returns if the classes are already shown;
        raises `Exception` if no class array is set.
        '''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        # NoNorm indexes the colormap directly with the integer class values.
        kwargs.update({
            'cmap': cm,
            'vmin': 0,
            'norm': NoNorm(),
            'interpolation': self._interpolation
        })
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Updates the displayed data (if it has been shown).'''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        if self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes,
                                         mask=(self.classes == 0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''Opacity in [0, 1] applied to the class image when overlaid.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        if alpha < 0 or alpha > 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib interpolation name used for both the data and class images.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the axes title to `s`; does nothing until the view is shown.'''
        if not self.is_shown:
            return
        self.axes.set_title(s)
        self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event (wired up through
        ParentViewPanCallback, which is defined elsewhere).

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view. Defaults to `settings.imshow_zoom_pixel_width`.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        # Square figure of the configured width; otherwise matplotlib's
        # default figure size is used.
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        # The zoomed view reuses this view's source, data, bands, RGB
        # options, classes and class colors.
        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        # Explicit extent (left, right, bottom, top) centers pixel (r, c) at
        # x=c, y=r with row 0 at the top, for both data and class layers.
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        # NOTE(review): this runs before show(). When classes are set it
        # already builds the class layer (set_display_mode -> show_classes),
        # so show() below emits the "show_classes should only be called
        # once" warning. It does make display_mode known before the class
        # layer is built (needed for overlay masking) - confirm intent.
        view.set_display_mode(self.display_mode)
        # Share the common callback registry so events posted there (e.g.
        # class modifications) reach both views.
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        # Initial window covers `size` pixels from the origin; y limits are
        # given as (size, 0) so row 0 stays at the top.
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).

        The current zoom extent and axis orientation are kept; only the
        center moves. Raises Exception if the view has not been shown yet.
        '''
        ax = self.axes
        if ax is None:
            raise Exception('Cannot pan image until it is shown.')
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()
        # Half-spans keep their sign, so an inverted y axis stays inverted.
        half_w = (x_hi - x_lo) / 2.0
        half_h = (y_hi - y_lo) / 2.0
        ax.set_xlim(col - half_w, col + half_w)
        ax.set_ylim(row - half_h, row + half_h)
        ax.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out about the current center (`scale` > 1 zooms in).'''
        ax = self.axes
        (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
        cx = (x_lo + x_hi) / 2.0
        cy = (y_lo + y_hi) / 2.0
        half_w = (x_hi - x_lo) / 2.0 / scale
        half_h = (y_hi - y_lo) / 2.0 / scale

        ax.set_xlim(cx - half_w, cx + half_w)
        ax.set_ylim(cy - half_h, cy + half_h)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.

        Returns "" when (x, y) lies outside the image (pixel centers are at
        integer coordinates, so the image spans -0.5 .. size - 0.5). Otherwise
        returns "pixel=[row,col]", followed by " class=N" when class values
        are available.
        '''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: the class suffix is optional, and the status
                # bar readout must not raise. Unlike a bare `except:`, this
                # does not swallow KeyboardInterrupt/SystemExit.
                pass
        return s

    def __str__(self):
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += '  {0:<20}:  {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation == None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += '  {0:<20}:  {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += '  {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += '    {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
Beispiel #5
0
class Video_Bbox(object):
    '''Interactive bounding-box annotator for a sequence of images.

    A box is drawn with the left or right mouse button (the middle button is
    ignored). A class label is chosen from the radio buttons, and the
    "Submit" button appends the box to ``annotations.csv``. Key "n" advances
    to the next image (exiting after the last one), "q" exits, and the mouse
    wheel zooms the image axes about the cursor.
    '''

    def __init__(self, fig, ax, img_paths, classes):
        '''
        Arguments:

            `fig`: matplotlib figure whose canvas receives the events.

            `ax`: axes in which the images are displayed.

            `img_paths`: sequence of image file paths, annotated in order.

            `classes`: labels offered by the radio-button panel.
        '''
        self.RS = RectangleSelector(
            ax,
            self.line_select_callback,
            drawtype='box',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)

        fig.canvas.mpl_connect('key_press_event', self.toggle_selector)
        # The selector is blitted; re-apply it after every full canvas redraw
        # so the current box stays visible.
        fig.canvas.mpl_connect('draw_event', self.persist_rectangle)
        ax.set_yticklabels([])
        ax.set_xticklabels([])

        self.ax = ax
        self.fig = fig
        # Radio-button panel along the left 20% of the current pyplot figure
        # (assumed to be `fig`).
        self.axradio = plt.axes([0.0, 0.0, 0.2, 1])
        self.radio = RadioButtons(self.axradio, classes)
        self.zoom_scale = 1.2
        self.img_paths = img_paths
        self.zoom_id = fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.axsubmit = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_submit = Button(self.axsubmit, 'Submit')
        self.b_submit.on_clicked(self.submit)

        self.index = 0
        img = plt.imread(self.img_paths[self.index])
        self.ax.imshow(img, aspect='auto')

    def line_select_callback(self, eclick, erelease):
        '''RectangleSelector callback (press/release events).

        Nothing to do here: the selected box is read from `self.RS.extents`
        when the user presses "Submit".
        '''

    def persist_rectangle(self, event):
        '''draw_event handler: redraws the active selector's artists.'''
        if self.RS.active:
            self.RS.update()

    def zoom(self, event):
        '''scroll_event handler: zooms the image axes about the cursor.

        Scrolling "down" zooms in and "up" zooms out, by `self.zoom_scale`.
        Events over any other axes (the radio panel or the Submit button) are
        ignored, because their xdata/ydata are in those axes' coordinates.
        '''
        if event.inaxes is not self.ax:
            return
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location

        if event.button == 'down':
            # deal with zoom in
            scale_factor = 1 / self.zoom_scale
        elif event.button == 'up':
            # deal with zoom out
            scale_factor = self.zoom_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Relative cursor position within the current limits; preserving it
        # keeps the point under the cursor fixed while zooming.
        relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim(
            [xdata - new_width * (1 - relx), xdata + new_width * (relx)])
        self.ax.set_ylim(
            [ydata - new_height * (1 - rely), ydata + new_height * (rely)])
        self.ax.figure.canvas.draw()

    def submit(self, event):
        '''Button handler: records the current box and marks it on the image.

        Appends one row to ``annotations.csv``:
        image path, x_min, y_min, x_max, y_max (truncated to int), class label.
        The box is then drawn as a translucent green patch and the selector
        rectangle is hidden.
        '''
        if not self.is_empty():
            # RectangleSelector.extents is (xmin, xmax, ymin, ymax).
            bbox = self.RS.extents
            # newline='' is required by the csv module to avoid spurious
            # blank lines on platforms that translate line endings.
            with open('annotations.csv', 'a', newline='') as f:
                csv_writer = csv.writer(f,
                                        delimiter=',',
                                        quotechar='|',
                                        quoting=csv.QUOTE_MINIMAL)
                csv_writer.writerow([
                    os.path.abspath(self.img_paths[self.index]),
                    int(bbox[0]),
                    int(bbox[2]),
                    int(bbox[1]),
                    int(bbox[3]), self.radio.value_selected
                ])
            rect = patches.Rectangle((bbox[0], bbox[2]),
                                     bbox[1] - bbox[0],
                                     bbox[3] - bbox[2],
                                     linewidth=1,
                                     edgecolor='g',
                                     facecolor='g',
                                     alpha=0.4)
            self.ax.add_patch(rect)
            self.RS.to_draw.set_visible(False)
            self.fig.canvas.draw()

    def toggle_selector(self, event):
        '''key_press_event handler: "n"/"N" shows the next image, "q"/"Q" exits.'''
        import sys

        if event.key in ['N', 'n']:
            self.ax.clear()
            self.index += 1
            if self.index == len(self.img_paths):
                # Past the last image: nothing left to annotate.
                sys.exit()
            img = plt.imread(self.img_paths[self.index])
            # Same aspect handling as the first image drawn in __init__.
            self.ax.imshow(img, aspect='auto')
            self.ax.set_yticklabels([])
            self.ax.set_xticklabels([])
            self.fig.canvas.draw()

        if event.key in ['q', 'Q']:
            sys.exit()

    def is_empty(self):
        '''Returns True if no box has been drawn yet.

        NOTE(review): relies on the private RectangleSelector attribute
        `_rect_bbox` and on (0, 0, 0, 1) being its untouched value; this may
        differ between matplotlib versions - confirm.
        '''
        return self.RS._rect_bbox == (0, 0, 0, 1)
Beispiel #6
0
class ImageView(object):
    '''Class to manage events and data associated with image raster views.

    In most cases, it is more convenient to simply call :func:`~spectral.graphics.spypylab.imshow`,
    which creates, displays, and returns an :class:`ImageView` object. Creating
    an :class:`ImageView` object directly (or creating an instance of a subclass)
    enables additional customization of the image display (e.g., overriding
    default event handlers). If the object is created directly, call the
    :meth:`show` method to display the image. The underlying image display
    functionality is implemented via :func:`matplotlib.pyplot.imshow`.
    '''
    # Patch properties for the RectangleSelector created in init_callbacks.
    selector_rectprops = dict(facecolor='red', edgecolor='black',
                              alpha=0.5, fill=True)
    # NOTE(review): not referenced inside this class; presumably kept for
    # subclasses or external callers.
    selector_lineprops = dict(color='black', linestyle='-',
                              linewidth=2, alpha=0.5)

    def __init__(self, data=None, bands=None, classes=None, source=None,
                 **kwargs):
        '''
        Arguments:

            `data` (ndarray or :class:`SpyFile`):

                The source of RGB bands to be displayed, with shape (R, C, B).
                If the shape is (R, C, 3), the last dimension is assumed to
                provide the red, green, and blue bands (unless the `bands`
                argument is provided). If :math:`B > 3` and `bands` is not
                provided, the first, middle, and last band will be used.

            `bands` (triplet of integers):

                Specifies which bands in `data` should be displayed as red,
                green, and blue, respectively.

            `classes` (ndarray of integers):

                An array of integer-valued class labels with shape (R, C). If
                the `data` argument is provided, the shape must match the first
                two dimensions of `data`.

            `source` (ndarray or :class:`SpyFile`):

                The source of spectral data associated with the image display.
                This optional argument is used to access spectral data (e.g., to
                generate a spectrum plot when a user double-clicks on the image
                display).

        Keyword arguments:

            Any keyword that can be provided to :func:`~spectral.graphics.graphics.get_rgb`
            or :func:`matplotlib.imshow`.
        '''

        import spectral
        from spectral import settings
        self.is_shown = False
        self.imshow_data_kwargs = {'cmap': settings.imshow_float_cmap}
        self.rgb_kwargs = {}
        # zorder 1, presumably so the class layer draws above the data layer.
        self.imshow_class_kwargs = {'zorder': 1}

        self.data = data
        self.data_rgb = None
        self.data_rgb_meta = {}
        self.classes = None
        self.class_rgb = None
        self.source = None
        self.bands = bands
        self.data_axes = None
        self.class_axes = None
        self.axes = None
        self._image_shape = None
        self.display_mode = None
        self._interpolation = None
        self.selection = None
        # Goes through the property setter; since the view is not shown yet,
        # this only records the value.
        self.interpolation = kwargs.get('interpolation',
                                        settings.imshow_interpolation)

        if data is not None:
            self.set_data(data, bands, **kwargs)
        if classes is not None:
            self.set_classes(classes, **kwargs)
        if source is not None:
            self.set_source(source)

        # Default class palette. set_classes above is never given `colors`
        # by this constructor, so nothing is overwritten here.
        self.class_colors = spectral.spy_colors

        self.spectrum_plot_fig_id = None
        self.parent = None
        self.selector = None
        self._on_parent_click_cid = None
        self._class_alpha = settings.imshow_class_alpha

        # Callbacks for events associated specifically with this window.
        self.callbacks = None

        # A sharable callback registry for related windows. If this
        # CallbackRegistry is set prior to calling ImageView.show (e.g., by
        # setting it equal to the `callbacks_common` member of another
        # ImageView object), then the registry will be shared. Otherwise, a new
        # callback registry will be created for this ImageView.
        self.callbacks_common = None

        check_disable_mpl_callbacks()

    def set_data(self, data, bands=None, **kwargs):
        '''Sets the data to be shown in the RGB channels.

        Arguments:

            `data` (ndarray or SpyImage):

                If `data` has more than 3 bands, the `bands` argument can be
                used to specify which 3 bands to display. `data` will be
                passed to `get_rgb` prior to display.

            `bands` (3-tuple of int):

                Indices of the 3 bands to display from `data`.

        Keyword Arguments:

            Any valid keyword for `get_rgb` or `matplotlib.imshow` can be
            given.

        Raises ValueError if the first two dimensions of `data` differ from
        those of previously set data/classes.
        '''
        from .graphics import _get_rgb_kwargs

        self.data = data
        self.bands = bands

        # Split off the get_rgb keywords; whatever remains goes to imshow.
        rgb_kwargs = {}
        for k in _get_rgb_kwargs:
            if k in kwargs:
                rgb_kwargs[k] = kwargs.pop(k)
        self.set_rgb_options(**rgb_kwargs)

        self._update_data_rgb()

        if self._image_shape is None:
            self._image_shape = data.shape[:2]
        elif data.shape[:2] != self._image_shape:
            raise ValueError('Image shape is inconsistent with previously ' \
                             'set data.')
        self.imshow_data_kwargs.update(kwargs)
        # Interpolation is tracked via the property rather than as a raw
        # imshow kwarg, so it applies to both the data and class layers.
        if 'interpolation' in self.imshow_data_kwargs:
            self.interpolation = self.imshow_data_kwargs['interpolation']
            self.imshow_data_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_data only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_rgb_options(self, **kwargs):
        '''Sets parameters affecting RGB display of data.

        Accepts any keyword supported by :func:`~spectral.graphics.graphics.get_rgb`.
        Raises ValueError for any other keyword.
        '''
        from .graphics import _get_rgb_kwargs

        for k in kwargs:
            if k not in _get_rgb_kwargs:
                raise ValueError('Unexpected keyword: {0}'.format(k))
        self.rgb_kwargs = kwargs.copy()
        if self.is_shown:
            self._update_data_rgb()
            self.refresh()

    def _update_data_rgb(self):
        '''Regenerates the RGB values for display.'''
        from .graphics import get_rgb_meta

        (self.data_rgb, self.data_rgb_meta) = \
          get_rgb_meta(self.data, self.bands, **self.rgb_kwargs)

        # If it is a gray-scale image, only keep the first RGB component so
        # matplotlib imshow's cmap can still be used (imshow ignores the
        # colormap for 3-channel input).
        if self.data_rgb_meta['mode'] == 'monochrome' and \
           self.data_rgb.ndim == 3:
            self.data_rgb = self.data_rgb[:, :, 0]

    def set_classes(self, classes, colors=None, **kwargs):
        '''Sets the array of class values associated with the image data.

        Arguments:

            `classes` (ndarray of int):

                `classes` must be an integer-valued array with the same
                number rows and columns as the display data (if set).

            `colors`: (array or 3-tuples):

                Color triplets (with values in the range [0, 255]) that
                define the colors to be associated with the integer indices
                in `classes`.

        Keyword Arguments:

            Any valid keyword for `matplotlib.imshow` can be provided.

        Passing `classes=None` clears the class array and returns immediately.
        Raises ValueError if the class shape conflicts with previously set
        data.
        '''
        from .graphics import _get_rgb_kwargs
        self.classes = classes
        if classes is None:
            return
        if self._image_shape is None:
            self._image_shape = classes.shape[:2]
        elif classes.shape[:2] != self._image_shape:
            raise ValueError('Class data shape is inconsistent with ' \
                             'previously set data.')
        if colors is not None:
            self.class_colors = colors

        # get_rgb keywords are irrelevant for the class layer; drop them.
        kwargs = {k: v for (k, v) in kwargs.items()
                  if k not in _get_rgb_kwargs}
        self.imshow_class_kwargs.update(kwargs)

        if 'interpolation' in self.imshow_class_kwargs:
            self.interpolation = self.imshow_class_kwargs['interpolation']
            self.imshow_class_kwargs.pop('interpolation')

        if len(kwargs) > 0 and self.is_shown:
            msg = 'Keyword args to set_classes only have an effect if ' \
              'given before the image is shown.'
            warnings.warn(UserWarning(msg))
        if self.is_shown:
            self.refresh()

    def set_source(self, source):
        '''Sets the image data source (used for accessing spectral data).

        Arguments:

            `source` (ndarray or :class:`SpyFile`):

                The source for spectral data associated with the view.
        '''
        self.source = source

    def show(self, mode=None, fignum=None):
        '''Renders the image data.

        Arguments:

            `mode` (str):

                Must be one of:

                    "data":          Show the data RGB

                    "classes":       Shows indexed color for `classes`

                    "overlay":       Shows class colors overlaid on data RGB.

                If `mode` is not provided, a mode will be automatically
                selected, based on the data set in the ImageView.

            `fignum` (int):

                Figure number of the matplotlib figure in which to display
                the ImageView. If not provided, a new figure will be created.

        Calling this more than once only emits a warning.
        '''
        import matplotlib.pyplot as plt
        from spectral import settings

        if self.is_shown:
            msg = 'ImageView.show should only be called once.'
            warnings.warn(UserWarning(msg))
            return

        set_mpl_interactive()

        kwargs = {}
        if fignum is not None:
            kwargs['num'] = fignum
        if settings.imshow_figure_size is not None:
            kwargs['figsize'] = settings.imshow_figure_size
        plt.figure(**kwargs)

        if self.data_rgb is not None:
            self.show_data()
        if self.classes is not None:
            self.show_classes()

        if mode is None:
            self._guess_mode()
        else:
            self.set_display_mode(mode)

        self.axes.format_coord = self.format_coord

        self.init_callbacks()
        self.is_shown = True

    def init_callbacks(self):
        '''Creates the object's callback registry and default callbacks.'''
        from spectral import settings
        from matplotlib.cbook import CallbackRegistry

        self.callbacks = CallbackRegistry()

        # callbacks_common may have been set to a shared external registry
        # (e.g., to the callbacks_common member of another ImageView object). So
        # don't create it if it has already been set.
        if self.callbacks_common is None:
            self.callbacks_common = CallbackRegistry()

        # Mouse callback
        self.cb_mouse = ImageViewMouseHandler(self)
        self.cb_mouse.connect()

        # Keyboard callback
        self.cb_keyboard = ImageViewKeyboardHandler(self)
        self.cb_keyboard.connect()

        # Class update event callback: keeps this view in sync when any view
        # sharing `callbacks_common` modifies the class array.
        def updater(*args, **kwargs):
            if self.classes is None:
                self.set_classes(args[0].classes)
            self.refresh()
        callback = MplCallback(registry=self.callbacks_common,
                               event='spy_classes_modified',
                               callback=updater)
        callback.connect()
        self.cb_classes_modified = callback

        if settings.imshow_enable_rectangle_selector is False:
            return
        try:
            from matplotlib.widgets import RectangleSelector
            # NOTE(review): `drawtype` and `rectprops` are deprecated in newer
            # matplotlib releases; if construction fails, interactive
            # labeling is disabled by the handler below.
            self.selector = RectangleSelector(self.axes,
                                              self._select_rectangle,
                                              button=1,
                                              useblit=True,
                                              spancoords='data',
                                              drawtype='box',
                                              rectprops = \
                                                  self.selector_rectprops)
            self.selector.set_active(False)
        except Exception:
            # Best effort: the view is still usable without the selector.
            self.selector = None
            msg = 'Failed to create RectangleSelector object. Interactive ' \
              'pixel class labeling will be unavailable.'
            warnings.warn(UserWarning(msg))

    def label_region(self, rectangle, class_id):
        '''Assigns all pixels in the rectangle to the specified class.

        Arguments:

            `rectangle` (4-tuple of integers):

                Tuple or list defining the rectangle bounds. Should have the
                form (row_start, row_stop, col_start, col_stop), where the
                stop indices are not included (i.e., the effect is
                `classes[row_start:row_stop, col_start:col_stop] = id`.

            class_id (integer >= 0):

                The class to which pixels will be assigned.

        Returns the number of pixels reassigned (the number of pixels in the
        rectangle whose class has *changed* to `class_id`).
        '''
        if self.classes is None:
            self.classes = np.zeros(self.data_rgb.shape[:2], dtype=np.int16)
        r = rectangle
        n = np.sum(self.classes[r[0]:r[1], r[2]:r[3]] != class_id)
        if n > 0:
            self.classes[r[0]:r[1], r[2]:r[3]] = class_id
            # Notify every view sharing the common registry.
            event = SpyMplEvent('spy_classes_modified')
            event.classes = self.classes
            event.nchanged = n
            self.callbacks_common.process('spy_classes_modified', event)
            # Make selection rectangle go away. The selector is None when it
            # was disabled in settings or failed to construct.
            if self.selector is not None:
                self.selector.to_draw.set_visible(False)
            self.refresh()
            return n
        return 0

    def _select_rectangle(self, event1, event2):
        '''RectangleSelector callback: records the selected pixel region.

        Stores `self.selection` as [row_start, row_stop, col_start, col_stop]
        (stop indices exclusive), clipped to the image, or None if either
        event lies outside these axes or the box misses the image entirely.
        '''
        if event1.inaxes is not self.axes or event2.inaxes is not self.axes:
            self.selection = None
            return
        (r1, c1) = xy_to_rowcol(event1.xdata, event1.ydata)
        (r2, c2) = xy_to_rowcol(event2.xdata, event2.ydata)
        (r1, r2) = sorted([r1, r2])
        (c1, c2) = sorted([c1, c2])
        if (r2 < 0) or (r1 >= self._image_shape[0]) or \
           (c2 < 0) or (c1 >= self._image_shape[1]):
            self.selection = None
            return
        r1 = max(r1, 0)
        r2 = min(r2, self._image_shape[0] - 1)
        c1 = max(c1, 0)
        c2 = min(c2, self._image_shape[1] - 1)
        print('Selected region: [%d: %d, %d: %d]' % (r1, r2 + 1, c1, c2 + 1))
        self.selection = [r1, r2 + 1, c1, c2 + 1]
        self.selector.set_active(False)
        # Make the rectangle display until at least the next event
        self.selector.to_draw.set_visible(True)
        self.selector.update()

    def _guess_mode(self):
        '''Select an appropriate display mode, based on current data.'''
        if self.data_rgb is not None:
            self.set_display_mode('data')
        elif self.classes is not None:
            self.set_display_mode('classes')
        else:
            raise Exception('Unable to display image: no data set.')

    def show_data(self):
        '''Show the image data (creates the data layer; call only once).'''
        import matplotlib.pyplot as plt
        if self.data_axes is not None:
            msg = 'ImageView.show_data should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.data_rgb is None:
            raise Exception('Unable to display data: data array not set.')
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.imshow_data_kwargs['interpolation'] = self._interpolation
        self.data_axes = plt.imshow(self.data_rgb, **self.imshow_data_kwargs)
        if self.axes is None:
            self.axes = self.data_axes.axes

    def show_classes(self):
        '''Show the class values (creates the class layer; call only once).'''
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap, NoNorm

        if self.class_axes is not None:
            msg = 'ImageView.show_classes should only be called once.'
            warnings.warn(UserWarning(msg))
            return
        elif self.classes is None:
            raise Exception('Unable to display classes: class array not set.')

        # Class colors are given in [0, 255]; matplotlib expects [0, 1].
        cm = ListedColormap(np.array(self.class_colors) / 255.)
        self._update_class_rgb()
        kwargs = self.imshow_class_kwargs.copy()

        # NoNorm maps each integer class value directly to a colormap index.
        kwargs.update({'cmap': cm, 'vmin': 0, 'norm': NoNorm(),
                       'interpolation': self._interpolation})
        if self.axes is not None:
            # A figure has already been created for the view. Make it current.
            plt.figure(self.axes.figure.number)
        self.class_axes = plt.imshow(self.class_rgb, **kwargs)
        if self.axes is None:
            self.axes = self.class_axes.axes
        self.class_axes.set_zorder(1)
        if self.display_mode == 'overlay':
            self.class_axes.set_alpha(self._class_alpha)
        else:
            self.class_axes.set_alpha(1)

    def refresh(self):
        '''Updates the displayed data (if it has been shown).

        Existing layers get fresh pixel data and interpolation; a layer that
        the current display mode needs but that was never drawn is created.
        '''
        if self.is_shown:
            self._update_class_rgb()
            if self.class_axes is not None:
                self.class_axes.set_data(self.class_rgb)
                self.class_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('classes', 'overlay'):
                self.show_classes()
            if self.data_axes is not None:
                self.data_axes.set_data(self.data_rgb)
                self.data_axes.set_interpolation(self._interpolation)
            elif self.display_mode in ('data', 'overlay'):
                self.show_data()
            self.axes.figure.canvas.draw()

    def _update_class_rgb(self):
        '''Rebuilds `class_rgb`; class 0 is masked (transparent) in overlay mode.'''
        if self.display_mode == 'overlay':
            self.class_rgb = np.ma.array(self.classes, mask=(self.classes==0))
        else:
            self.class_rgb = np.array(self.classes)

    def set_display_mode(self, mode):
        '''`mode` must be one of ("data", "classes", "overlay").'''
        if mode not in ('data', 'classes', 'overlay'):
            raise ValueError('Invalid display mode: ' + repr(mode))
        self.display_mode = mode

        show_data = mode in ('data', 'overlay')
        if self.data_axes is not None:
            self.data_axes.set_visible(show_data)

        show_classes = mode in ('classes', 'overlay')
        if self.classes is not None and self.class_axes is None:
            # Class data values were just set
            self.show_classes()
        if self.class_axes is not None:
            self.class_axes.set_visible(show_classes)
            if mode == 'classes':
                self.class_axes.set_alpha(1)
            else:
                self.class_axes.set_alpha(self._class_alpha)
        self.refresh()

    @property
    def class_alpha(self):
        '''alpha transparency for the class overlay.'''
        return self._class_alpha

    @class_alpha.setter
    def class_alpha(self, alpha):
        # The chained comparison also rejects NaN (NaN fails every
        # comparison, so `alpha < 0 or alpha > 1` would accept it).
        if not 0 <= alpha <= 1:
            raise ValueError('Alpha value must be in range [0, 1].')
        self._class_alpha = alpha
        if self.class_axes is not None:
            self.class_axes.set_alpha(alpha)
        if self.is_shown:
            self.refresh()

    @property
    def interpolation(self):
        '''matplotlib pixel interpolation to use in the image display.'''
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        if interpolation == self._interpolation:
            return
        self._interpolation = interpolation
        if not self.is_shown:
            return
        if self.data_axes is not None:
            self.data_axes.set_interpolation(interpolation)
        if self.class_axes is not None:
            self.class_axes.set_interpolation(interpolation)
        self.refresh()

    def set_title(self, s):
        '''Sets the axes title; ignored until the view is shown.'''
        if self.is_shown:
            self.axes.set_title(s)
            self.refresh()

    def open_zoom(self, center=None, size=None):
        '''Opens a separate window with a zoomed view.
        If a ctrl-lclick event occurs in the original view, the zoomed window
        will pan to the location of the click event.

        Arguments:

            `center` (two-tuple of int):

                Initial (row, col) of the zoomed view.

            `size` (int):

                Width and height (in source image pixels) of the initial
                zoomed view. Defaults to `settings.imshow_zoom_pixel_width`.

        Returns:

        A new ImageView object for the zoomed view.
        '''
        from spectral import settings
        import matplotlib.pyplot as plt
        if size is None:
            size = settings.imshow_zoom_pixel_width
        (nrows, ncols) = self._image_shape
        fig_kwargs = {}
        if settings.imshow_zoom_figure_width is not None:
            width = settings.imshow_zoom_figure_width
            fig_kwargs['figsize'] = (width, width)
        fig = plt.figure(**fig_kwargs)

        view = ImageView(source=self.source)
        view.set_data(self.data, self.bands, **self.rgb_kwargs)
        view.set_classes(self.classes, self.class_colors)
        view.imshow_data_kwargs = self.imshow_data_kwargs.copy()
        # Pixel (r, c) centered at x=c, y=r with row 0 at the top.
        kwargs = {'extent': (-0.5, ncols - 0.5, nrows - 0.5, -0.5)}
        view.imshow_data_kwargs.update(kwargs)
        view.imshow_class_kwargs = self.imshow_class_kwargs.copy()
        view.imshow_class_kwargs.update(kwargs)
        # NOTE(review): called before show(); with classes set this builds
        # the class layer early, so show() warns that show_classes should
        # only be called once. It does fix display_mode before the class
        # layer is built (overlay masking) - confirm intent.
        view.set_display_mode(self.display_mode)
        view.callbacks_common = self.callbacks_common
        view.show(fignum=fig.number, mode=self.display_mode)
        view.axes.set_xlim(0, size)
        view.axes.set_ylim(size, 0)
        view.interpolation = 'nearest'
        if center is not None:
            view.pan_to(*center)
        view.cb_parent_pan = ParentViewPanCallback(view, self)
        view.cb_parent_pan.connect()
        return view

    def pan_to(self, row, col):
        '''Centers view on pixel coordinate (row, col).'''
        if self.axes is None:
            raise Exception('Cannot pan image until it is shown.')
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        # Signed half-spans preserve the (possibly inverted) axis orientation.
        xrange_2 = (xmax - xmin) / 2.0
        yrange_2 = (ymax - ymin) / 2.0
        self.axes.set_xlim(col - xrange_2, col + xrange_2)
        self.axes.set_ylim(row - yrange_2, row + yrange_2)
        self.axes.figure.canvas.draw()

    def zoom(self, scale):
        '''Zooms view in/out (`scale` > 1 zooms in).'''
        (xmin, xmax) = self.axes.get_xlim()
        (ymin, ymax) = self.axes.get_ylim()
        x = (xmin + xmax) / 2.0
        y = (ymin + ymax) / 2.0
        dx = (xmax - xmin) / 2.0 / scale
        dy = (ymax - ymin) / 2.0 / scale

        self.axes.set_xlim(x - dx, x + dx)
        self.axes.set_ylim(y - dy, y + dy)
        self.refresh()

    def format_coord(self, x, y):
        '''Formats pixel coordinate string displayed in the window.'''
        (nrows, ncols) = self._image_shape
        if x < -0.5 or x > ncols - 0.5 or y < -0.5 or y > nrows - 0.5:
            return ""
        (r, c) = xy_to_rowcol(x, y)
        s = 'pixel=[%d,%d]' % (r, c)
        if self.classes is not None:
            try:
                s += ' class=%d' % self.classes[r, c]
            except Exception:
                # Best effort: the class suffix is optional and the status
                # readout must not raise.
                pass
        return s

    def __str__(self):
        '''Returns a human-readable summary of the display settings.'''
        meta = self.data_rgb_meta
        s = 'ImageView object:\n'
        if 'bands' in meta:
            s += '  {0:<20}:  {1}\n'.format("Display bands", meta['bands'])
        if self.interpolation is None:
            interp = "<default>"
        else:
            interp = self.interpolation
        s += '  {0:<20}:  {1}\n'.format("Interpolation", interp)
        if 'rgb range' in meta:
            s += '  {0:<20}:\n'.format("RGB data limits")
            for (c, r) in zip('RGB', meta['rgb range']):
                s += '    {0}: {1}\n'.format(c, str(r))
        return s

    def __repr__(self):
        return str(self)
class InteractiveCut(object):
    """Interactive 1D cuts driven by a rectangle drawn on a slice plot.

    The pixel aspect of a newly drawn rectangle picks the cut direction:
    wider than tall gives a cut along x integrated over y, otherwise a cut
    along y integrated over x. Dragging the rectangle re-plots the cut live.
    """

    def __init__(self, slice_plot, canvas, ws_title):
        self.slice_plot = slice_plot
        self._canvas = canvas
        self._ws_title = ws_title
        self._en_unit = slice_plot.get_slice_cache().energy_axis.e_unit
        self._en_from_meV = EnergyUnits(self._en_unit).factor_from_meV()

        self.horizontal = None
        # Matplotlib connection ids (None while not connected):
        #   [0] motion_notify_event  -> live re-plot during a drag
        #   [1] button_release_event -> end of a drag
        #   [2] button_press_event   -> start of a drag (after the first rectangle)
        #   [3] draw_event           -> keep the rectangle drawn after redraws
        self.connect_event = [None, None, None, None]
        # We need to access the CutPlotterPresenter instance of the particular CutPlot (window) we are using
        # But there is no way to get without changing the active category then calling the GlobalFigureManager.
        # So we create a new temporary here. After the first time we plot a 1D plot, the correct category is set
        # and we can get the correct CutPlot instance and its CutPlotterPresenter
        self._cut_plotter_presenter = CutPlotterPresenter()
        self._is_initial_cut_plotter_presenter = True
        self._rect_pos_cache = [0, 0, 0, 0, 0, 0]
        self.rect = RectangleSelector(self._canvas.figure.gca(), self.plot_from_mouse_event,
                                      drawtype='box', useblit=True,
                                      button=[1, 3], spancoords='pixels', interactive=True)

        self.connect_event[3] = self._canvas.mpl_connect('draw_event', self.redraw_rectangle)
        self._canvas.draw()

    def plot_from_mouse_event(self, eclick, erelease):
        """RectangleSelector callback: plot the cut for the selected rectangle.

        `eclick`/`erelease` are the press and release events. The orientation
        is recomputed only when every pixel coordinate and size differs from
        the previous selection by more than 0.1 (an entirely new rectangle).
        """
        # Make axis orientation sticky, until user selects entirely new rectangle.
        rect_pos = [eclick.x, eclick.y, erelease.x, erelease.y,
                    abs(erelease.x - eclick.x), abs(erelease.y - eclick.y)]
        rectangle_changed = all(abs(new - old) > 0.1
                                for new, old in zip(rect_pos, self._rect_pos_cache))
        if rectangle_changed:
            self.horizontal = abs(erelease.x - eclick.x) > abs(erelease.y - eclick.y)
        self.plot_cut(eclick.xdata, erelease.xdata, eclick.ydata, erelease.ydata)
        # Connect the drag-start handler only once. Reconnecting on every
        # selection stacked duplicate `clicked` handlers, each of which left
        # an orphaned motion handler behind after the drag ended.
        if self.connect_event[2] is None:
            self.connect_event[2] = self._canvas.mpl_connect('button_press_event', self.clicked)
        self._rect_pos_cache = rect_pos

    def plot_cut(self, x1, x2, y1, y2, store=False):
        """Plot the cut for the rectangle [x1, x2] x [y1, y2].

        Does nothing unless x2 > x1 and y2 > y1. `store` is forwarded to
        the presenter's plot_interactive_cut.
        """
        if x2 > x1 and y2 > y1:
            ax, integration_start, integration_end = self.get_cut_parameters((x1, y1), (x2, y2))
            units = self._canvas.figure.gca().get_yaxis().units if self.horizontal else \
                self._canvas.figure.gca().get_xaxis().units
            integration_axis = Axis(units, integration_start, integration_end, 0, self._en_unit)
            cut = Cut(ax, integration_axis, None, None)
            self._cut_plotter_presenter.plot_interactive_cut(str(self._ws_title), cut, store)
            self._cut_plotter_presenter.set_is_icut(True)
            if self._is_initial_cut_plotter_presenter:
                # First time we've plotted a 1D cut - get the true CutPlotterPresenter
                from mslice.plotting.pyplot import GlobalFigureManager
                self._cut_plotter_presenter = GlobalFigureManager.get_active_figure().plot_handler._cut_plotter_presenter
                self._is_initial_cut_plotter_presenter = False
                GlobalFigureManager.disable_make_current()
            self._cut_plotter_presenter.store_icut(self)

    def get_cut_parameters(self, pos1, pos2):
        """Split two (x, y) corners into a cut axis and an integration range.

        Returns (ax, integration_start, integration_end). `ax` runs along x
        when self.horizontal, else along y; its step is
        get_limits(...)[2] scaled by self._en_from_meV.
        """
        # Booleans index as 0/1: pos[not horizontal] is x for a horizontal cut.
        start = pos1[not self.horizontal]
        end = pos2[not self.horizontal]
        units = self._canvas.figure.gca().get_xaxis().units if self.horizontal else \
            self._canvas.figure.gca().get_yaxis().units
        step = get_limits(get_workspace_handle(self._ws_title), units)[2] * self._en_from_meV
        ax = Axis(units, start, end, step, self._en_unit)
        integration_start = pos1[self.horizontal]
        integration_end = pos2[self.horizontal]
        return ax, integration_start, integration_end

    def clicked(self, event):
        """Start a drag: re-plot on every mouse motion until the button is released."""
        # Drop drag handlers left from an unfinished drag so that exactly one
        # motion/release pair is ever connected.
        self.end_drag(event)
        self.connect_event[0] = self._canvas.mpl_connect('motion_notify_event',
                                                         lambda x: self.plot_cut(*self.rect.extents))
        self.connect_event[1] = self._canvas.mpl_connect('button_release_event', self.end_drag)

    def end_drag(self, event):
        """Stop live re-plotting by disconnecting the drag handlers."""
        for index in (0, 1):
            if self.connect_event[index] is not None:
                self._canvas.mpl_disconnect(self.connect_event[index])
                self.connect_event[index] = None

    def redraw_rectangle(self, event):
        """draw_event handler: redraw the selector while it is active."""
        if self.rect.active:
            self.rect.update()

    def save_cut(self):
        """Store the current cut and return its output workspace name."""
        x1, x2, y1, y2 = self.rect.extents
        self.plot_cut(x1, x2, y1, y2, store=True)
        self.update_workspaces()
        _, integration_start, integration_end = self.get_cut_parameters((x1, y1), (x2, y2))
        return output_workspace_name(str(self._ws_title), integration_start, integration_end)

    def update_workspaces(self):
        self.slice_plot.update_workspaces()

    def clear(self):
        """Leave interactive-cut mode: deactivate the selector and drop all handlers."""
        self._cut_plotter_presenter.set_is_icut(False)
        self.rect.set_active(False)
        for cid in self.connect_event:
            if cid is not None:
                self._canvas.mpl_disconnect(cid)
        self.connect_event[:] = [None, None, None, None]
        self._canvas.draw()

    def flip_axis(self):
        """Swap the cut direction and re-plot the current rectangle."""
        self.horizontal = not self.horizontal
        self.plot_cut(*self.rect.extents)

    def window_closing(self):
        self.slice_plot.toggle_interactive_cuts()
        self.slice_plot.plot_window.action_interactive_cuts.setChecked(False)
Beispiel #8
0
class InteractiveCut(object):
    """Interactive 1D cuts driven by a rectangle drawn on a slice plot.

    The pixel aspect of a newly drawn rectangle picks the cut direction:
    wider than tall gives a cut along x integrated over y, otherwise a cut
    along y integrated over x. Dragging the rectangle re-plots the cut live.
    """

    def __init__(self, slice_plot, canvas, ws_title):
        self.slice_plot = slice_plot
        self._canvas = canvas
        self._ws_title = ws_title

        self.horizontal = None
        # Matplotlib connection ids (None while not connected):
        #   [0] motion_notify_event  -> live re-plot during a drag
        #   [1] button_release_event -> end of a drag
        #   [2] button_press_event   -> start of a drag (after the first rectangle)
        #   [3] draw_event           -> keep the rectangle drawn after redraws
        self.connect_event = [None, None, None, None]
        self._cut_plotter_presenter = CutPlotterPresenter()
        self._rect_pos_cache = [0, 0, 0, 0, 0, 0]
        self.rect = RectangleSelector(self._canvas.figure.gca(),
                                      self.plot_from_mouse_event,
                                      drawtype='box',
                                      useblit=True,
                                      button=[1, 3],
                                      spancoords='pixels',
                                      interactive=True)

        self.connect_event[3] = self._canvas.mpl_connect(
            'draw_event', self.redraw_rectangle)
        self._canvas.draw()

    def plot_from_mouse_event(self, eclick, erelease):
        """RectangleSelector callback: plot the cut for the selected rectangle.

        `eclick`/`erelease` are the press and release events. The orientation
        is recomputed only when every pixel coordinate and size differs from
        the previous selection by more than 0.1 (an entirely new rectangle).
        """
        # Make axis orientation sticky, until user selects entirely new rectangle.
        rect_pos = [
            eclick.x, eclick.y, erelease.x, erelease.y,
            abs(erelease.x - eclick.x),
            abs(erelease.y - eclick.y)
        ]
        rectangle_changed = all(
            abs(new - old) > 0.1
            for new, old in zip(rect_pos, self._rect_pos_cache))
        if rectangle_changed:
            self.horizontal = abs(erelease.x - eclick.x) > abs(erelease.y -
                                                               eclick.y)
        self.plot_cut(eclick.xdata, erelease.xdata, eclick.ydata,
                      erelease.ydata)
        # Connect the drag-start handler only once. Reconnecting on every
        # selection stacked duplicate `clicked` handlers, each of which left
        # an orphaned motion handler behind after the drag ended.
        if self.connect_event[2] is None:
            self.connect_event[2] = self._canvas.mpl_connect(
                'button_press_event', self.clicked)
        self._rect_pos_cache = rect_pos

    def plot_cut(self, x1, x2, y1, y2, store=False):
        """Plot the cut for the rectangle [x1, x2] x [y1, y2].

        Does nothing unless x2 > x1 and y2 > y1. `store` is forwarded to
        the presenter's plot_interactive_cut.
        """
        if x2 > x1 and y2 > y1:
            ax, integration_start, integration_end = self.get_cut_parameters(
                (x1, y1), (x2, y2))
            units = self._canvas.figure.gca().get_yaxis().units if self.horizontal else \
                self._canvas.figure.gca().get_xaxis().units
            integration_axis = Axis(units, integration_start, integration_end,
                                    0)
            self._cut_plotter_presenter.plot_interactive_cut(
                str(self._ws_title), ax, integration_axis, store)
            self._cut_plotter_presenter.store_icut(self._ws_title, self)

    def get_cut_parameters(self, pos1, pos2):
        """Split two (x, y) corners into a cut axis and an integration range.

        Returns (ax, integration_start, integration_end). `ax` runs along x
        when self.horizontal, else along y, with the step from
        get_limits(...)[2].
        """
        # Booleans index as 0/1: pos[not horizontal] is x for a horizontal cut.
        start = pos1[not self.horizontal]
        end = pos2[not self.horizontal]
        units = self._canvas.figure.gca().get_xaxis().units if self.horizontal else \
            self._canvas.figure.gca().get_yaxis().units
        step = get_limits(get_workspace_handle(self._ws_title), units)[2]
        ax = Axis(units, start, end, step)
        integration_start = pos1[self.horizontal]
        integration_end = pos2[self.horizontal]
        return ax, integration_start, integration_end

    def clicked(self, event):
        """Start a drag: re-plot on every mouse motion until the button is released."""
        # Drop drag handlers left from an unfinished drag so that exactly one
        # motion/release pair is ever connected.
        self.end_drag(event)
        self.connect_event[0] = self._canvas.mpl_connect(
            'motion_notify_event', lambda x: self.plot_cut(*self.rect.extents))
        self.connect_event[1] = self._canvas.mpl_connect(
            'button_release_event', self.end_drag)

    def end_drag(self, event):
        """Stop live re-plotting by disconnecting the drag handlers."""
        for index in (0, 1):
            if self.connect_event[index] is not None:
                self._canvas.mpl_disconnect(self.connect_event[index])
                self.connect_event[index] = None

    def redraw_rectangle(self, event):
        """draw_event handler: redraw the selector while it is active."""
        if self.rect.active:
            self.rect.update()

    def save_cut(self):
        """Store the current cut and return its output workspace name."""
        x1, x2, y1, y2 = self.rect.extents
        self.plot_cut(x1, x2, y1, y2, store=True)
        self.update_workspaces()
        _, integration_start, integration_end = self.get_cut_parameters(
            (x1, y1), (x2, y2))
        return output_workspace_name(str(self._ws_title), integration_start,
                                     integration_end)

    def update_workspaces(self):
        self.slice_plot.update_workspaces()

    def clear(self):
        """Leave interactive-cut mode: deactivate the selector and drop all handlers."""
        self._cut_plotter_presenter.set_is_icut(self._ws_title, False)
        self.rect.set_active(False)
        for cid in self.connect_event:
            if cid is not None:
                self._canvas.mpl_disconnect(cid)
        self.connect_event[:] = [None, None, None, None]
        self._canvas.draw()

    def flip_axis(self):
        """Swap the cut direction and re-plot the current rectangle."""
        self.horizontal = not self.horizontal
        self.plot_cut(*self.rect.extents)

    def window_closing(self):
        self.slice_plot.toggle_interactive_cuts()
        self.slice_plot.plot_window.action_interactive_cuts.setChecked(False)
Beispiel #9
0
class IsoCenter_Child(QMain, IsoCenter.Ui_IsoCenter):
    "Class that contains subroutines to define isocenter from Lynx image"

    def __init__(self, parent, owner):
        super(IsoCenter_Child, self).__init__()

        self.Owner = owner
        self.setupUi(self)
        self.setStyleSheet(parent.styleSheet())

        self.parent = parent
        self.canvas = self.Display_IsoCenter.canvas
        self.toolbar = self.canvas.toolbar

        # Connect buttons
        self.Button_LoadSpot.clicked.connect(self.load)
        self.Button_detectIsoCenter.clicked.connect(self.drawRect)
        self.Button_SetIsoCenter.clicked.connect(self.LockIsoCenter)
        self.Button_Done.clicked.connect(self.Done)

        # Second slot on the same button. Qt calls slots in connection order,
        # so initclick runs right after drawRect has created self.RS/self.bg.
        self.Button_detectIsoCenter.clicked.connect(self.initclick)

        # Flags and Containers
        self.Image = None
        self.press = None
        self.rects = []
        # Artists added by spotDetect (fit contour and crosshair lines);
        # hidden again by drawRect before the next detection.
        self.target_markers = []

        # Set by LockIsoCenter; cleared by update_crosshair and failed loads.
        # Done asks for confirmation while it is False.
        self.IsoCenter_flag = False

    def drawRect(self):
        """Hide earlier detection markers and arm a RectangleSelector for the ROI.

        The selected rectangle is passed on to spotDetect through
        line_select_callback.
        """
        # Remove previous spotdetections
        for item in self.target_markers:
            if isinstance(item, matplotlib.contour.QuadContourSet):
                for artist in item.collections:
                    artist.set_visible(False)
            else:
                item.set_visible(False)

        # change cursor style (restored in spotDetect)
        QApplication.setOverrideCursor(Qt.CrossCursor)

        # Rectangle selector for 2d fit
        rectprops = dict(facecolor='orange',
                         edgecolor=None,
                         alpha=0.2,
                         fill=True)
        # drawtype is 'box' or 'line' or 'none'
        self.RS = RectangleSelector(
            self.canvas.axes,
            self.line_select_callback,
            drawtype='box',
            rectprops=rectprops,
            button=[1],  # left mouse button only
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            useblit=True,
            interactive=True)
        self.canvas.draw()
        # Background for blitting; handed to the selector in initclick.
        self.bg = self.canvas.copy_from_bbox(self.RS.ax.bbox)
        self.RS.set_visible(True)

        # Draw an initial box with extents (xmin, xmax, ymin, ymax) = (0, 4, 0, 1)
        ext = (0, 4, 0, 1)
        self.RS.draw_shape(ext)

        # Update displayed handles
        self.RS._corner_handles.set_data(*self.RS.corners)
        self.RS._edge_handles.set_data(*self.RS.edge_centers)
        self.RS._center_handle.set_data(*self.RS.center)
        for artist in self.RS.artists:
            self.RS.ax.draw_artist(artist)
            artist.set_animated(False)
        self.canvas.draw()
        self.cid = self.canvas.mpl_connect("button_press_event",
                                           self.initclick)

    def line_select_callback(self, eclick, erelease):
        """RectangleSelector callback: run spot detection on the selected box."""
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata

        p1 = (x1, y1)
        p2 = (x2, y2)

        self.spotDetect(p1, p2)

    def initclick(self, evt):
        """Give the selector the saved background and switch its artists to animated."""
        self.RS.background = self.bg
        self.RS.update()
        for artist in self.RS.artists:
            artist.set_animated(True)
        self.canvas.mpl_disconnect(self.cid)

    def load(self):
        "load radiography image of beam IsoCenter"

        # get filename from full path and display
        fname = Qfile.getOpenFileName(self, 'Open file', "",
                                      "Dicom files (*.dcm *tiff *tif)")[0]
        if not fname:
            # Dialog was cancelled: keep the current image and state.
            return
        try:
            # import imagedata with regard to filetype
            if fname.endswith("dcm"):
                meta = dcm.read_file(fname)
                self.Image = RadiographyImage(fname, meta.pixel_array,
                                              meta.PixelSpacing)

            elif fname.endswith("tif") or fname.endswith("tiff"):
                pw, okx = QInputDialog.getDouble(self,
                                                 'Pixel Spacing',
                                                 'pixel width (mm):',
                                                 0.05,
                                                 decimals=2)
                self.Image = RadiographyImage(fname, tifffile.imread(fname),
                                              pw)

            self.Text_Filename.setText(fname)  # display filename
            self.canvas.axes.imshow(self.Image.array,
                                    cmap='gray',
                                    zorder=1,
                                    origin='lower')
            self.canvas.draw()
            logging.info('{:s} imported as Isocenter'.format(fname))

        except Exception:
            # logging.error, not logging.ERROR (the integer level constant).
            logging.error("{:s} could not be opened".format(fname))
            self.IsoCenter_flag = False
            return 0

    def LockIsoCenter(self):
        """ Read current values from sliders/ spot location text fields
        and set as final isocenter coordinates to be used for the
        actual positioning"""
        self.SpotTxt_x.setStyleSheet("color: rgb(255, 0, 0);")
        self.SpotTxt_y.setStyleSheet("color: rgb(255, 0, 0);")
        # Raise flag for checksum check later
        self.IsoCenter_flag = True

        # Function to pass IsoCenter values to parent window
        self.Owner.return_isocenter(
            self.Image, [self.SpotTxt_x.value(),
                         self.SpotTxt_y.value()])
        logging.info('Isocenter coordinates confirmed')

    def update_crosshair(self):
        """Get value from Spinboxes and update all
        markers/plots if that value is changed"""
        # NOTE(review): self.hline / self.vline are never assigned in this
        # class; confirm they are set elsewhere before this is called.
        x = self.SpotTxt_x.value()
        y = self.SpotTxt_y.value()

        # Update Plot Markers
        self.hline.set_ydata(y)
        self.vline.set_xdata(x)

        # Update Plot
        self.Display_IsoCenter.canvas.draw()

        self.SpotTxt_x.setStyleSheet("color: rgb(0, 0, 0);")
        self.SpotTxt_y.setStyleSheet("color: rgb(0, 0, 0);")
        self.IsoCenter_flag = False

    def spotDetect(self, p1, p2):
        " Function that is invoked by ROI selection, autodetects earpin"

        # Restore old cursor
        QApplication.restoreOverrideCursor()

        # Get ROI limits from drawn rectangle corners, rounded to pixels
        x = int(min(p1[0], p2[0]) + 0.5)
        y = int(min(p1[1], p2[1]) + 0.5)
        width = int(np.abs(p1[0] - p2[0]) + 0.5)
        height = int(np.abs(p1[1] - p2[1]) + 0.5)

        # Clamp the ROI origin to the image: a negative start index would make
        # numpy count from the opposite edge and select the wrong region.
        x0, y0 = max(x, 0), max(y, 0)
        subset = self.Image.array[y0:y + height, x0:x + width]

        # Calculate fit function values. The offsets passed along are where
        # `subset` starts in the full image (assumes find_center uses them as
        # the subset's origin).
        try:
            popt, pcov = find_center(subset, x0, y0, sigma=5.0)
            logging.info('Detected coordinates for isocenter:'
                         'x = {:2.1f}, y = {:2.1f}'.format(popt[1], popt[2]))
        except Exception:
            logging.error('Autodetection of Landmark in ROI failed.')
            return 0

        xx, yy, xrange, yrange = array2mesh(self.Image.array)
        data_fitted = twoD_Gaussian((xx, yy), *popt)

        # Print markers into image
        ax = self.canvas.axes
        self.target_markers.append(
            ax.contour(xx, yy, data_fitted.reshape(yrange, xrange), 5))
        # axvline/axhline span limits are axes fractions (0..1); the defaults
        # span the whole axes, so no data-coordinate limits are passed.
        self.target_markers.append(ax.axvline(popt[1]))
        self.target_markers.append(ax.axhline(popt[2]))
        self.canvas.draw()

        self.SpotTxt_x.setValue(popt[1])
        self.SpotTxt_y.setValue(popt[2])

        logging.info('Coordinates of IsoCenter set to '
                     'x = {:.1f}, y = {:.1f}'.format(popt[1], popt[2]))

    def Done(self):
        "Ends IsoCenter Definition and closes Child"
        # Also check whether all values were locked to main window
        if not self.IsoCenter_flag:
            Hint = QMessage()
            Hint.setStandardButtons(QMessage.No | QMessage.Yes)
            Hint.setIcon(QMessage.Information)
            Hint.setText("Some values have not been locked or were modified!"
                         "\nProceed?")
            answer = Hint.exec_()
            if answer == QMessage.Yes:
                self.close()
        else:
            self.close()
class COCO_dataset_generator(object):
    """Interactive matplotlib tool for drawing COCO-style box annotations.

    Images matching `img_dir`/*.jpg are shown one at a time in `ax`. The user
    draws a box with the RectangleSelector, picks its class with the radio
    buttons and uses the keys handled in onkeyboard. Unless the module-level
    `args['no_feedback']` is set, a Mask R-CNN model pre-populates each image
    with red predicted boxes. The Save button writes output.json.
    """

    def __init__(self, fig, ax, img_dir, classes, model_path, json_file):
        """
        fig, ax    -- figure and image axes to draw into
        img_dir    -- directory scanned for *.jpg images
        classes    -- text file; the first comma-separated field of each
                      line is a class name (names are sorted)
        model_path -- not used here; model settings come from the
                      module-level `args` dict
        json_file  -- existing COCO-style json to continue from, or None
        """
        self.RS = RectangleSelector(ax, self.line_select_callback,
                                    drawtype='box', useblit=True,
                                    button=[1, 3],  # don't use middle button
                                    minspanx=5, minspany=5,
                                    spancoords='pixels',
                                    interactive=True)

        ax.set_yticklabels([])
        ax.set_xticklabels([])

        with open(classes, 'r') as f:
            self.classes = sorted(line.strip().split(',')[0] for line in f)
        img_paths = glob.glob(os.path.abspath(os.path.join(img_dir, '*.jpg')))
        plt.tight_layout()

        self.ax = ax
        self.fig = fig
        self.axradio = plt.axes([0.0, 0.0, 0.1, 1])
        self.radio = RadioButtons(self.axradio, self.classes)
        self.zoom_scale = 1.2
        self.zoom_id = self.fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.keyboard_id = self.fig.canvas.mpl_connect('key_press_event', self.onkeyboard)
        self.selected_poly = False
        self.axsave = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_save = Button(self.axsave, 'Save')
        self.b_save.on_clicked(self.save)
        self.objects, self.existing_patches, self.existing_rects = [], [], []
        self.num_pred = 0
        # Stays None when model feedback is disabled; onkeyboard checks it.
        self.model = None
        if json_file is None:
            self.images, self.annotations = [], []
            self.index = 0
            self.ann_id = 0
        else:
            with open(json_file, 'r') as g:
                d = json.load(g)
            self.images, self.annotations = d['images'], d['annotations']
            self.index = len(self.images)
            self.ann_id = len(self.annotations)

        # Set for O(1) membership tests in the loop below.
        prev_files = {x['file_name'] for x in self.images}
        for i, f in enumerate(img_paths):
            if f in prev_files:
                self.index += 1
                continue
            # Close the file handle once the size has been read.
            with Image.open(f) as im:
                width, height = im.size
            self.images.append({'file_name': f, 'id': self.index + i,
                                'height': height, 'width': width})
        image = plt.imread(self.images[self.index]['file_name'])
        self.ax.imshow(image, aspect='auto')

        if not args['no_feedback']:
            from mask_rcnn.get_json_config import get_demo_config
            from mask_rcnn import model as modellib

            self.config = get_demo_config(len(self.classes)-1, True)

            if 'config_path' in args:
                self.config.from_json(args['config_path'])

            plt.connect('draw_event', self.persist)

            # Create model object in inference mode.
            self.model = modellib.MaskRCNN(mode="inference", model_dir='/'.join(args['weights_path'].split('/')[:-2]), config=self.config)

            # Load weights trained on MS-COCO
            self.model.load_weights(args['weights_path'], by_name=True)

            self._add_predictions(image)
            self.num_pred = len(self.objects)

    def _add_predictions(self, image):
        """Run the model on `image` and draw each predicted box as a red patch.

        Appends the label, box corner points and patch of every prediction
        to self.objects, self.existing_patches and self.existing_rects.
        """
        r = self.model.detect([image], verbose=0)[0]
        # Each roi row is used as (y1, x1, y2, x2) when building the patch.
        for class_id, roi in zip(r['class_ids'], r['rois']):
            label = self.classes[class_id-1]
            y1, x1, y2, x2 = roi[0], roi[1], roi[2], roi[3]
            pat = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1,
                                    edgecolor='r', facecolor='r', alpha=0.4)
            self.ax.add_patch(pat)

            self.objects.append(label)
            self.existing_patches.append(pat.get_bbox().get_points())
            self.existing_rects.append(pat)

    def line_select_callback(self, eclick, erelease):
        """RectangleSelector callback (press/release events).

        Intentionally a no-op: the current box is read from self.RS.extents
        when the user presses 'i'.
        """

    def zoom(self, event):
        """Scroll-wheel zoom about the cursor by a factor of self.zoom_scale."""
        if not event.inaxes:
            return
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()

        xdata = event.xdata  # get event x location
        ydata = event.ydata  # get event y location

        if event.button == 'down':
            # scrolling down shrinks the visible range (zooms in)
            scale_factor = 1 / self.zoom_scale
        elif event.button == 'up':
            # scrolling up enlarges the visible range (zooms out)
            scale_factor = self.zoom_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
            print(event.button)

        new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
        new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

        # Keep the point under the cursor at the same relative position.
        relx = (cur_xlim[1] - xdata)/(cur_xlim[1] - cur_xlim[0])
        rely = (cur_ylim[1] - ydata)/(cur_ylim[1] - cur_ylim[0])

        self.ax.set_xlim([xdata - new_width * (1-relx), xdata + new_width * (relx)])
        self.ax.set_ylim([ydata - new_height * (1-rely), ydata + new_height * (rely)])
        self.ax.figure.canvas.draw()

    def save(self, event):
        """Write images up to the current one, all annotations and the classes to output.json."""
        data = {'images': self.images[:self.index+1], 'annotations': self.annotations,
                'categories': [], 'classes': self.classes}

        with open('output.json', 'w') as outfile:
            json.dump(data, outfile)

    def persist(self, event):
        """draw_event handler: redraw the selector while it is active."""
        if self.RS.active:
            self.RS.update()

    def onkeyboard(self, event):
        """Keyboard shortcuts (ignored unless the cursor is over an axes).

        'a'     -- pick up the existing box under the cursor for editing: select
                   its class, hide its patch, load it into the selector and drop
                   it from the existing-box lists
        'i'     -- record the selector's box as an annotation of the selected
                   class and draw it in green
        'n'/'N' -- advance to the next image (exits after the last one)
        'q'/'Q' -- exit
        """
        if not event.inaxes:
            return
        elif event.key == 'a':
            for i, ((xmin, ymin), (xmax, ymax)) in enumerate(self.existing_patches):
                if xmin <= event.xdata <= xmax and ymin <= event.ydata <= ymax:
                    self.radio.set_active(self.classes.index(self.objects[i]))
                    self.RS.set_active(True)
                    self.rectangle = self.existing_rects[i]
                    self.rectangle.set_visible(False)
                    coords = self.rectangle.get_bbox().get_points()
                    self.RS.extents = [coords[0][0], coords[1][0], coords[0][1], coords[1][1]]
                    self.RS.to_draw.set_visible(True)
                    self.existing_rects.pop(i)
                    self.existing_patches.pop(i)
                    self.objects.pop(i)
                    # self.fig, not a module-level `fig`.
                    self.fig.canvas.draw()
                    break

        elif event.key == 'i':
            b = self.RS.extents  # xmin, xmax, ymin, ymax
            b = [int(x) for x in b]
            if b[1]-b[0] > 0 and b[3]-b[2] > 0:
                poly = [b[0], b[2], b[0], b[3], b[1], b[3], b[1], b[2], b[0], b[2]]
                area = (b[1]-b[0])*(b[3]-b[2])
                # NOTE(review): COCO's bbox convention is [x, y, width, height];
                # this stores [xmin, ymin, xmax, ymax]. Confirm what consumers expect.
                bbox = [b[0], b[2], b[1], b[3]]
                dic2 = {'segmentation': [poly], 'area': area, 'iscrowd': 0, 'image_id': self.index, 'bbox': bbox, 'category_id': self.classes.index(self.radio.value_selected)+1, 'id': self.ann_id}
                if dic2 not in self.annotations:
                    self.annotations.append(dic2)
                self.ann_id += 1
                rect = patches.Rectangle((b[0], b[2]), b[1]-b[0], b[3]-b[2], linewidth=1, edgecolor='g', facecolor='g', alpha=0.4)
                self.ax.add_patch(rect)

                # Selector is inactive while the new patch is drawn, then re-enabled.
                self.RS.set_active(False)

                self.fig.canvas.draw()
                self.RS.set_active(True)
        elif event.key in ['N', 'n']:
            self.ax.clear()
            self.index += 1
            # If self.objects still has as many entries as were predicted (no
            # box was picked up with 'a'), the current image is removed.
            if len(self.objects) == self.num_pred:
                self.images.pop(self.index-1)
                self.index -= 1
            if self.index == len(self.images):
                raise SystemExit
            image = plt.imread(self.images[self.index]['file_name'])
            self.ax.imshow(image)
            self.ax.set_yticklabels([])
            self.ax.set_xticklabels([])
            self.existing_rects, self.existing_patches, self.objects = [], [], []
            # Without model feedback there is no model to run.
            if self.model is not None:
                self._add_predictions(image)
            self.num_pred = len(self.objects)
            self.fig.canvas.draw()

        elif event.key in ['q', 'Q']:
            raise SystemExit
Beispiel #11
0
class Axes2D(object):
    """Image plot of 2D data on `ax` with an optional rectangle selector for cuts.

    `parent` is presumably the matplotlib Qt canvas hosting `ax` (it is used
    for draw(), mpl_connect()/mpl_disconnect(), a `status` attribute and Qt
    widget geometry in context_menu) -- TODO confirm.
    """

    def __init__(self, ax, parent):
        self.ax = ax
        self.parent = parent

        self.params_2d = {}  # keyword arguments forwarded to imshow
        self.y, self.x1, self.x2, self.data = [], [], [], []
        # (x1, y1, x2, y2) data coordinates of the last selection
        self.rectangle_coordinates = None
        self.RectangleSelector = None
        self._draw_cid = None  # draw_event connection made by set_rectangle
        self.cut_window = None
        self.colormap_window = None
        self.apply_log_status = False
        self.plot_obj = None
        self.Ranges = Plot2DRangesHandler(self)

    def set_rectangle(self):
        """(Re)create the rectangle selector on self.ax and restore the last selection."""
        self.RectangleSelector = RectangleSelector(
            self.ax,
            self.line_select_callback,
            drawtype='box',
            button=[1],  # left mouse button only
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)

        def update_rs(event):
            # Redraw the selector after each canvas draw while it is active
            # and parent.status == 2.
            if self.RectangleSelector is not None:
                if self.RectangleSelector.active and self.parent.status == 2:
                    self.RectangleSelector.update()

        # This runs on every redraw while the cut window is open; replace the
        # previous draw_event hook instead of stacking another one.
        if self._draw_cid is not None:
            self.parent.mpl_disconnect(self._draw_cid)
        self._draw_cid = self.parent.mpl_connect('draw_event', update_rs)
        if self.rectangle_coordinates is not None:
            x1, y1, x2, y2 = self.rectangle_coordinates
            self.RectangleSelector.extents = (x1, x2, y1, y2)
            self.RectangleSelector.update()

    def line_select_callback(self, eclick, erelease):
        """eclick and erelease are the press and release events"""
        assert self.cut_window is not None
        self.rectangle_coordinates = eclick.xdata, eclick.ydata, erelease.xdata, erelease.ydata
        self.cut_window.canvas.update_cut_plot()

    def context_menu(self, event):
        """Show the right-click menu at the mouse position."""
        menu = QMenu()
        y = self.parent.parent().height()
        # Matplotlib's y grows upwards, Qt's downwards.
        position = self.parent.mapFromParent(QPoint(event.x, y - event.y))
        parameter_menu = menu.addMenu('Plot parameters')
        change_colormap_action = parameter_menu.addAction("Change colormap")
        change_colormap_action.triggered.connect(self.open_colormap_window)

        log_action_name = "Disable log" if self.apply_log_status else "Apply log"
        change_log_action = parameter_menu.addAction(log_action_name)
        change_log_action.triggered.connect(self.change_log_status)

        reset_action = parameter_menu.addAction("Reset parameters")
        reset_action.triggered.connect(self.reset_parameters)

        redraw_action = menu.addAction("Redraw graph")
        redraw_action.triggered.connect(self.redraw_2d_plot)

        menu.addSeparator()
        open_cut_window_action = menu.addAction("Open cut window")
        open_cut_window_action.triggered.connect(self.open_cut_window)
        open_cut_window_action.setEnabled(self.cut_window is None)
        menu.exec_(self.parent.parent().mapToGlobal(position))

    def open_colormap_window(self, event):
        range_init, range_whole = self.Ranges.get_ranges_for_colormap(self.y)
        self.colormap_window = ColormapWindow(range_init,
                                              range_whole,
                                              title='Colormap')
        self.colormap_window.set_callback(self.colormap_callback)
        self.colormap_window.show()

    def redraw_2d_plot(self):
        """Clear the axes and redraw the image (and the selector if a cut window is open)."""
        self.ax.cla()
        self.plot_obj = self.ax.imshow(self.y, **self.params_2d)
        if self.cut_window:
            self.set_rectangle()

    def colormap_callback(self, range_):
        self.Ranges.update_params(range_)
        self.redraw_2d_plot()
        self.parent.draw()

    def change_log_status(self, event):
        """Toggle between raw and log-scaled data and redraw."""
        self.apply_log_status = not self.apply_log_status
        if self.apply_log_status:
            self.y = self.apply_log(self.data)
        else:
            self.y = self.data
        self.Ranges.change_regime()
        self.redraw_2d_plot()
        self.parent.draw()

    def reset_parameters(self, event):
        """Drop log scaling and colour limits, then redraw."""
        self.apply_log_status = False
        self.params_2d.pop('vmin', None)
        self.params_2d.pop('vmax', None)
        self.y = self.data
        self.redraw_2d_plot()
        self.parent.draw()

    @staticmethod
    def apply_log(data):
        """Natural log of `data`, with values clipped below at max(0.1, min(data))."""
        min_value = max([0.1, np.amin(data)])
        max_value = np.amax(data)
        return np.log(np.clip(data, min_value, max_value))

    def update_plot(self, obj, file):
        """Replace the plotted data with obj[()] and refresh the cut plot.

        Axis values are read from `file` under obj.attrs['x_axis'] and
        obj.attrs['y_axis']. If they are missing or their lengths don't match
        the data shape, pixel indices are used instead. Any other error while
        reading them is printed and the update is abandoned.
        """
        self.data = obj[()]
        if self.apply_log_status:
            self.y = self.apply_log(self.data)
        else:
            self.y = self.data

        try:
            x_ax = file[obj.attrs['x_axis']][()]
            y_ax = file[obj.attrs['y_axis']][()]
            shape = self.y.shape
            if (len(y_ax), len(x_ax)) == shape:
                y_ax, x_ax = x_ax, y_ax
            elif (len(x_ax), len(y_ax)) != shape:
                # Explicit raise rather than `assert`, which `python -O` strips.
                raise AssertionError('shapes are wrong')
            self.x1 = y_ax
            self.x2 = x_ax
        except (AssertionError, TypeError, KeyError):
            self.params_2d.pop('extent', None)
            self.x1 = list(range(0, self.y.shape[1]))
            self.x2 = list(range(0, self.y.shape[0]))

        except Exception as er:
            print(er)
            return

        self.params_2d['extent'] = [self.x1[0], self.x1[-1], self.x2[0], self.x2[-1]]

        if self.plot_obj is not None:
            self.plot_obj.set_data(self.y)
            self.plot_obj.set_extent(self.params_2d['extent'])
            self.ax.relim()  # Recalculate limits
            self.ax.autoscale_view(True, True, True)
        else:
            self.plot_obj = self.ax.imshow(self.y, **self.params_2d)
        self.parent.draw()
        if self.cut_window:
            self.cut_window.canvas.update_cut_plot()

    def open_cut_window(self):
        self.cut_window = CutWindow(self)
        self.set_rectangle()
        if self.rectangle_coordinates:
            self.cut_window.canvas.update_cut_plot()

    def on_closing_cut_window(self):
        """Hide and drop the selector and its draw hook when the cut window closes."""
        if self._draw_cid is not None:
            self.parent.mpl_disconnect(self._draw_cid)
            self._draw_cid = None
        if self.RectangleSelector is not None:
            self.RectangleSelector.set_visible(False)
            self.RectangleSelector.update()
        self.RectangleSelector = None
        self.cut_window = None
class Video_Bbox(object):
    """Interactive bounding-box annotator for a directory of video frames.

    Frames are globbed from ``<basename of args.video_file>/*`` (relative to
    the current working directory) and sorted by the integer in each file's
    stem (``12.jpg`` -> 12). Draw a box with mouse button 1 or 3. Press
    ``n``/``N`` to append the box (if any) to ``annotations.csv`` and advance
    to the next frame, and ``q``/``Q`` to exit.

    NOTE(review): ``drawtype=`` and ``RS.to_draw`` are deprecated in newer
    Matplotlib releases -- confirm the targeted Matplotlib version.
    """

    def __init__(self, fig, ax, args):
        self.RS = RectangleSelector(
            ax,
            self.line_select_callback,
            drawtype='box',
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5,
            minspany=5,
            spancoords='pixels',
            interactive=True)

        fig.canvas.mpl_connect('key_press_event', self.toggle_selector)
        # Redraw the selector (when active) after every canvas draw.
        fig.canvas.mpl_connect('draw_event', self.persist_rectangle)
        ax.set_yticklabels([])
        ax.set_xticklabels([])

        self.ax = ax
        self.fig = fig
        self.cls_name = args.class_name
        self.img_paths = sorted(
            glob.glob(os.path.join(os.path.basename(args.video_file), '*')),
            key=self._frame_number)

        self.index = 0
        img = plt.imread(self.img_paths[self.index])
        self.ax.imshow(img)

    @staticmethod
    def _frame_number(path):
        """Return the integer frame number encoded in the file stem of ``path``."""
        # splitext handles extensions of any length (.jpg, .jpeg, .png, ...).
        return int(os.path.splitext(os.path.basename(path))[0])

    def persist_rectangle(self, event):
        """Draw-event handler: refresh the selector's artists while it is active."""
        if self.RS.active:
            self.RS.update()

    def line_select_callback(self, eclick, erelease):
        """Selector callback for the press/release events; intentionally a no-op.

        The box is read from ``self.RS.extents`` when the frame is saved.
        """

    def toggle_selector(self, event):
        """Key handler: ``n``/``N`` saves the box and advances, ``q``/``Q`` exits."""
        import sys

        if event.key in ['N', 'n']:
            if not self.is_empty():
                self._save_annotation()
                print('Frame %d/%d' % (self.index + 1, len(self.img_paths)))
            if self.index + 1 >= len(self.img_paths):
                # Stay on the last frame instead of indexing past the end
                # of img_paths (which raised IndexError in plt.imread).
                print('No more frames')
                return
            self._show_frame(self.index + 1)

        if event.key in ['q', 'Q']:
            sys.exit()

    def _save_annotation(self):
        """Append ``path, xmin, ymin, xmax, ymax, class`` for the current frame to annotations.csv."""
        # RectangleSelector.extents is (xmin, xmax, ymin, ymax).
        xmin, xmax, ymin, ymax = self.RS.extents
        # newline='' is required by the csv module for correct line endings.
        with open('annotations.csv', 'a', newline='') as f:
            csv_writer = csv.writer(f,
                                    delimiter=',',
                                    quotechar='|',
                                    quoting=csv.QUOTE_MINIMAL)
            csv_writer.writerow([
                os.path.abspath(self.img_paths[self.index]),
                int(xmin),
                int(ymin),
                int(xmax),
                int(ymax), self.cls_name
            ])

    def _show_frame(self, index):
        """Clear the axes and display frame ``index`` with tick labels hidden."""
        self.ax.clear()
        self.index = index
        img = plt.imread(self.img_paths[self.index])
        self.ax.imshow(img)
        self.RS.to_draw.set_visible(True)
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])
        self.fig.canvas.draw()

    def is_empty(self):
        """Return True when no box has been drawn yet.

        NOTE(review): compares the private ``_rect_bbox`` against
        ``(0, 0, 0, 1)``, presumably the initial value in the Matplotlib
        version this targets -- confirm.
        """
        return self.RS._rect_bbox == (0, 0, 0, 1)