コード例 #1
0
class image_segmenter:
    def __init__(self,
                 img_dir,
                 classes,
                 overlay_alpha=.5,
                 figsize=(10, 10),
                 scroll_to_zoom=True,
                 zoom_scale=1.1):
        """
        Interactive lasso / flood-fill segmenter for a folder of training images.

        Masks are ``uint8`` arrays in which 0 means "unlabelled" and the class at
        index ``i`` of *classes* is stored as ``i + 1``.

        TODO allow for initializing with a shape instead of an image

        parameters
        ----------
        img_dir : string
            Directory that must contain a 'train/' sub-folder. The images to label
            are the files directly inside 'train/' whose extension is listed in
            VALID_IMAGE_TYPES. They are visited in sorted order.
        classes : Int or list
            Number of classes or a list of class names (at most 20).
        overlay_alpha : float, default .5
            Opacity of the class colour blended over labelled pixels.
        figsize : (float, float), default (10, 10)
            Passed to ``plt.figure``.
        scroll_to_zoom : boolean
            Currently unused; zooming is controlled by *zoom_scale* only.
        zoom_scale : float or None
            How much to scale the image per scroll. If you do this I recommend using jupyterlab-sidecar in order
            to prevent the page from scrolling. or checking in on: https://github.com/matplotlib/ipympl/issues/222
            To disable zoom set this to None.

        raises
        ------
        ValueError
            If ``img_dir/train`` does not exist or contains no images, or if
            more than 20 classes are requested.
        """
        self.img_dir = img_dir

        if not path.isdir(path.join(self.img_dir, 'train')):
            raise ValueError(
                f"{self.img_dir} must exist and contain the the folder 'train'"
            )
        self.img_dir = path.join(self.img_dir, 'train')
        # Masks go to '<prefix>train_masks/train', where <prefix> is everything
        # before the last 'train_imgs/' in the image directory. If the path
        # contains no 'train_imgs/', the prefix is the whole image directory.
        self.mask_dir = path.join(
            self.img_dir.rsplit('train_imgs/', 1)[0], 'train_masks/train')
        os.makedirs(self.mask_dir, exist_ok=True)

        # Sort so that next/previous navigation order is deterministic
        # (glob order is arbitrary).
        self.image_paths = sorted(
            img_path
            for type_ in VALID_IMAGE_TYPES
            for img_path in glob.glob(self.img_dir.rstrip('/') + f'/*.{type_}'))
        if not self.image_paths:
            raise ValueError(f'No images found in {self.img_dir}')
        self.shape = None

        plt.ioff()  # see https://github.com/matplotlib/matplotlib/issues/17013
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.gca()
        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        self.lasso = LassoSelector(self.ax,
                                   self.onselect,
                                   lineprops=lineprops,
                                   button=1,
                                   useblit=False)
        self.lasso.set_visible(True)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self._release)
        self.panhandler = panhandler(self.fig)
        plt.ion()

        if isinstance(classes, int):
            classes = np.arange(classes)
        n_classes = len(classes)
        if n_classes <= 10:
            cmap_name = 'tab10'
        elif n_classes <= 20:
            cmap_name = 'tab20'
        else:
            raise ValueError(
                f'Currently only up to 20 classes are supported, you tried to use {len(classes)} classes'
            )

        # Row 0 (black) belongs to the "unlabelled" value 0; row i + 1 is the
        # colour of class i, matching the values stored in class_mask.
        self.colors = np.vstack([[0, 0, 0],
                                 plt.get_cmap(cmap_name)(np.arange(n_classes))[:, :3]])

        self.class_dropdown = widgets.Dropdown(
            options=[(str(classes[i]), i) for i in range(n_classes)],
            value=0,
            description='Class:',
            disabled=False,
        )
        self.lasso_button = widgets.Button(
            description='lasso select',
            disabled=False,
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            icon='mouse-pointer',  # (FontAwesome names without the `fa-` prefix)
        )
        self.flood_button = widgets.Button(
            description='flood fill',
            disabled=False,
            button_style='',
            icon='fill-drip',
        )

        self.erase_check_box = widgets.Checkbox(value=False,
                                                description='Erase Mode',
                                                disabled=False,
                                                indent=False)

        self.reset_button = widgets.Button(
            description='reset',
            disabled=False,
            button_style='',
            icon='refresh',
        )
        self.save_button = widgets.Button(description='save mask',
                                          button_style='',
                                          icon='floppy-o')
        # With a single image there is nothing to step forward to.
        self.next_button = widgets.Button(description='next image',
                                          button_style='',
                                          icon='arrow-right',
                                          disabled=len(self.image_paths) <= 1)
        self.prev_button = widgets.Button(description='previous image',
                                          button_style='',
                                          icon='arrow-left',
                                          disabled=True)
        self.reset_button.on_click(self.reset)
        self.save_button.on_click(self.save_mask)
        self.next_button.on_click(self._change_image_idx)
        self.prev_button.on_click(self._change_image_idx)

        def button_click(button):
            # Toggle between flood-fill and lasso tools; only one is active at a time.
            if button.description == 'flood fill':
                self.flood_button.button_style = 'success'
                self.lasso_button.button_style = ''
                self.lasso.set_active(False)
            else:
                self.flood_button.button_style = ''
                self.lasso_button.button_style = 'success'
                self.lasso.set_active(True)

        self.lasso_button.on_click(button_click)
        self.flood_button.on_click(button_click)
        self.overlay_alpha = overlay_alpha
        self.indices = None
        self.new_image(0)

        # Must run after the image is created, and the image needs to come after all the buttons.
        if zoom_scale is not None:
            self.disconnect_scroll = zoom_factory(self.ax,
                                                  base_scale=zoom_scale)

    def _change_image_idx(self, button):
        """Save the current mask, then step to the next/previous image depending on *button*."""
        if button is self.next_button:
            new_idx = self.img_idx + 1
        elif button is self.prev_button:
            new_idx = self.img_idx - 1
        else:
            return
        if not 0 <= new_idx < len(self.image_paths):
            return
        # save_mask writes to self.mask_path, which still refers to the current image.
        self.save_mask()
        self.new_image(new_idx)
        self.prev_button.disabled = new_idx == 0
        self.next_button.disabled = new_idx == len(self.image_paths) - 1

    def new_image(self, img_idx):
        """Display image *img_idx* and load its saved mask, if one exists."""
        self.indices = None
        self.img_idx = img_idx
        img_path = self.image_paths[img_idx]
        self.img = io.imread(img_path)
        self.ax.set_title(os.path.basename(img_path))
        self.mask_path = os.path.join(self.mask_dir, os.path.basename(img_path))

        if self.img.shape != self.shape:
            # New geometry: rebuild the (x, y) pixel-centre grid used by the lasso
            # and replace the image artist instead of stacking a new one on top.
            self.shape = self.img.shape
            xv, yv = np.meshgrid(np.arange(self.shape[1]), np.arange(self.shape[0]))
            self.pix = np.vstack((xv.flatten(), yv.flatten())).T
            old_artist = getattr(self, 'displayed', None)
            if old_artist is not None:
                old_artist.remove()
            self.displayed = self.ax.imshow(self.img)
            # ensure that the _nav_stack is empty
            self.fig.canvas.toolbar._nav_stack.clear()
            # add the initial view to the stack so that the home button works.
            self.fig.canvas.toolbar.push_current()
        else:
            self.displayed.set_data(self.img)
            self.fig.canvas.toolbar.home()

        # Fresh buffer each time so a previously loaded mask of a different size is not reused.
        self.class_mask = np.zeros(self.shape[:2], dtype=np.uint8)
        if os.path.exists(self.mask_path):
            # NOTE(review): the saved mask is assumed to match the image's first two dims.
            self.class_mask = io.imread(self.mask_path)
        self.updateArray()

    def _release(self, event):
        self.panhandler.release(event)

    def reset(self, *args):
        """Remove every label from the current image and redraw it unannotated."""
        self.displayed.set_data(self.img)
        # 0 is the "unlabelled" value; -1 cannot be stored in the uint8 mask.
        self.class_mask[:, :] = 0
        self.fig.canvas.draw()

    def onclick(self, event):
        """
        Left click (with the lasso inactive) flood-fills the clicked region;
        right click starts panning.
        """
        if event.button == 1:
            if event.xdata is not None and not self.lasso.active:
                # imshow puts column on x and row on y; pixel centres sit at integer
                # coordinates, so round and clamp to the mask bounds.
                n_rows, n_cols = self.class_mask.shape[:2]
                row = min(max(int(round(event.ydata)), 0), n_rows - 1)
                col = min(max(int(round(event.xdata)), 0), n_cols - 1)
                self.indices = flood(self.class_mask, (row, col))
                self.updateArray()
        elif event.button == 3:
            self.panhandler.press(event)

    def updateArray(self):
        """
        Apply the pending selection (``self.indices``) to the mask and the display.

        Without a pending selection (just after ``new_image``), repaint every
        labelled pixel so that a loaded mask becomes visible.
        """
        array = self.displayed.get_array().data

        if self.indices is None:
            idx = self.class_mask != 0
            c_overlay = self.colors[
                self.class_mask[idx]] * 255 * self.overlay_alpha
            array[idx] = (c_overlay + self.img[idx] * (1 - self.overlay_alpha))
        elif self.erase_check_box.value:
            self.class_mask[self.indices] = 0
            array[self.indices] = self.img[self.indices]
        else:
            self.class_mask[self.indices] = self.class_dropdown.value + 1
            # Straight-alpha blend of the class colour over the image:
            # https://en.wikipedia.org/wiki/Alpha_compositing#Straight_versus_premultiplied
            c_overlay = self.colors[self.class_mask[
                self.indices]] * 255 * self.overlay_alpha
            array[self.indices] = (c_overlay + self.img[self.indices] *
                                   (1 - self.overlay_alpha))
        self.displayed.set_data(array)

    def onselect(self, verts):
        """Lasso callback: label every pixel centre inside the drawn polygon."""
        self.verts = verts
        p = Path(verts)
        self.indices = p.contains_points(self.pix, radius=0).reshape(self.shape[:2])

        self.updateArray()
        self.fig.canvas.draw_idle()

    def render(self):
        """Return the widget layout: tool buttons, class controls, canvas, navigation."""
        layers = [widgets.HBox([self.lasso_button, self.flood_button])]
        layers.append(
            widgets.HBox(
                [self.reset_button, self.class_dropdown,
                 self.erase_check_box]))
        layers.append(self.fig.canvas)
        layers.append(
            widgets.HBox(
                [self.save_button, self.prev_button, self.next_button]))
        return widgets.VBox(layers)

    def save_mask(self, save_if_no_nonzero=False):
        """
        Write the current mask to ``self.mask_path``.

        save_if_no_nonzero : boolean
            Whether to save if class_mask only contains 0s. When invoked by the
            save button this receives the (truthy) button object, so an explicit
            click always saves.
        """
        if save_if_no_nonzero or np.any(self.class_mask != 0):
            # splitext keeps the leading dot, e.g. '.jpg'.
            ext = os.path.splitext(self.mask_path)[1].lower()
            if ext in ('.jpg', '.jpeg'):
                # NOTE: JPEG is lossy even at quality=100; class ids may be altered.
                io.imsave(self.mask_path,
                          self.class_mask,
                          check_contrast=False,
                          quality=100)
            else:
                io.imsave(self.mask_path,
                          self.class_mask,
                          check_contrast=False)

    def _ipython_display_(self):
        display(self.render())
コード例 #2
0
class image_segmenter:
    """
    Manually segment an image with the lasso selector.
    """
    def __init__(
            self,
            img,
            nclasses=1,
            mask=None,
            mask_colors=None,
            mask_alpha=0.75,
            lineprops=None,
            figsize=(10, 10),
            **kwargs,
    ):
        """
        Build the figure, colour overlay and lasso for segmenting *img*.

        Extra ``kwargs`` are forwarded unchanged to the ``imshow`` call that draws *img*.

        parameters
        ----------
        img : array_like
            Anything ``imshow`` accepts.
        nclasses : int, default 1
            Number of foreground classes. Mask value 0 is "unlabelled" and values
            1..nclasses are classes.
        mask : arraylike, optional
            Initial label array. It is stored as given, without copying, and is
            edited in place by the lasso. A zero float array is created when omitted.
        mask_colors : None, color, or array of colors, optional
            One colour per class. By default the Tableau palette is used for up to
            10 classes and the XKCD palette beyond that. Unlabelled pixels are
            always fully transparent.
        mask_alpha : float, default .75
            Alpha of labelled pixels. It replaces any alpha given in *mask_colors*.
        lineprops : dict, default: None
            Forwarded to LassoSelector. When None,
            {"color": "black", "linewidth": 1, "alpha": 0.8} is used.
        figsize : (float, float), optional
            Forwarded to ``figure``.
        **kwargs:
            Forwarded to the ``imshow`` call for the image.
        """
        self.mask_alpha = mask_alpha

        if mask_colors is not None:
            # NOTE(review): the number of supplied colours is not checked against nclasses.
            palette = to_rgba_array(np.atleast_1d(mask_colors))
        elif nclasses > 10:
            # XKCD_COLORS offers several hundred named colours.
            palette = to_rgba_array(list(XKCD_COLORS)[:nclasses])
        else:
            palette = to_rgba_array(list(TABLEAU_COLORS)[:nclasses])
        palette[:, -1] = mask_alpha
        self.mask_colors = palette

        self._img = np.asarray(img)
        height_width = self._img.shape[:2]

        self.mask = np.zeros(height_width) if mask is None else mask

        # RGBA layer drawn above the image: class 0 stays transparent and
        # class k takes mask_colors[k - 1].
        self._overlay = np.zeros((*height_width, 4))
        self.nclasses = nclasses
        self._overlay[self.mask == 0] = [0, 0, 0, 0]
        for label in range(1, nclasses + 1):
            self._overlay[self.mask == label] = self.mask_colors[label - 1]

        with ioff:
            self.fig = figure(figsize=figsize)
            self.ax = self.fig.gca()
            self.displayed = self.ax.imshow(self._img, **kwargs)
            self._mask = self.ax.imshow(self._overlay)

        if lineprops is None:
            lineprops = {"color": "black", "linewidth": 1, "alpha": 0.8}
        # Blitting is disabled under the ipympl backend.
        blit = "ipympl" not in get_backend().lower()
        self.lasso = LassoSelector(self.ax,
                                   self._onselect,
                                   lineprops=lineprops,
                                   useblit=blit)
        self.lasso.set_visible(True)

        # (x, y) of every pixel centre in row-major order, matching self.mask.
        col_grid, row_grid = np.meshgrid(np.arange(self._img.shape[1]),
                                         np.arange(self._img.shape[0]))
        self.pix = np.vstack((col_grid.flatten(), row_grid.flatten())).T

        self.ph = panhandler(self.fig)
        self.disconnect_zoom = zoom_factory(self.ax)
        self.current_class = 1
        self.erasing = False

    def _onselect(self, verts):
        """Lasso callback: write (or erase) the current class inside the polygon *verts*."""
        self.verts = verts
        polygon = Path(verts)
        inside = polygon.contains_points(self.pix, radius=0).reshape(self.mask.shape)
        self.indices = inside
        if not self.erasing:
            self.mask[inside] = self.current_class
            self._overlay[inside] = self.mask_colors[self.current_class - 1]
        else:
            self.mask[inside] = 0
            self._overlay[inside] = [0, 0, 0, 0]

        self._mask.set_data(self._overlay)
        self.fig.canvas.draw_idle()

    def _ipython_display_(self):
        display(self.fig.canvas)
コード例 #3
0
ファイル: maskeditor.py プロジェクト: awacha/cct
class MaskEditor(ToolWindow, DoubleFileChooserDialog):
    """
    Tool window for interactively editing a mask and loading/saving it as a .mat file.

    In this class True marks pixels that are not masked: "new" creates an
    all-ones mask, and the mask mode clears the selected pixels.
    """

    def __init__(self, *args, **kwargs):
        self.mask = None
        self._undo_stack = []  # snapshots (copies) of previous masks, newest last
        self._im = None
        self._selector = None
        self._cursor = None
        self.exposureloader = None
        self.plot2d = None
        ToolWindow.__init__(self, *args, **kwargs)
        DoubleFileChooserDialog.__init__(
            self, self.widget, 'Open mask file...', 'Save mask file...', [('Mask files', '*.mat'), ('All files', '*')],
            self.instrument.config['path']['directories']['mask'],
            os.path.abspath(self.instrument.config['path']['directories']['mask']),
        )

    def init_gui(self, *args, **kwargs):
        self.exposureloader = ExposureLoader(self.instrument)
        self.builder.get_object('loadexposure_expander').add(self.exposureloader)
        self.exposureloader.connect('open', self.on_loadexposure)
        self.plot2d = PlotImageWidget()
        self.builder.get_object('plotbox').pack_start(self.plot2d.widget, True, True, 0)
        # Editing tools stay disabled until an exposure has been loaded.
        self.builder.get_object('toolbar').set_sensitive(False)

    def on_loadexposure(self, exposureloader: ExposureLoader, im: Exposure):
        """Show a loaded exposure; adopt its mask only if no mask is being edited yet."""
        if self.mask is None:
            self.mask = im.mask
        self._im = im
        self.plot2d.set_image(im.intensity)
        self.plot2d.set_mask(self.mask)
        self.builder.get_object('toolbar').set_sensitive(True)

    def on_new(self, button):
        if self._im is None or self.mask is None:
            return False
        self.mask = np.ones_like(self.mask)
        self.plot2d.set_mask(self.mask)
        self.set_last_filename(None)

    def on_open(self, button):
        filename = self.get_open_filename()
        if filename is not None:
            mask = loadmat(filename)
            # Use the first variable that is not MATLAB header metadata ('__header__' etc.).
            self.mask = mask[[k for k in mask.keys() if not k.startswith('__')][0]]
            self.plot2d.set_mask(self.mask)

    def on_save(self, button):
        filename = self.get_last_filename()
        if filename is None:
            return self.on_saveas(button)
        # The variable name inside the .mat file is the file's base name.
        maskname = os.path.splitext(os.path.split(filename)[1])[0]
        savemat(filename, {maskname: self.mask})

    def on_saveas(self, button):
        filename = self.get_save_filename(None)
        if filename is not None:
            self.on_save(button)

    def suggest_filename(self):
        return 'mask_dist_{0.year:d}{0.month:02d}{0.day:02d}.mat'.format(datetime.date.today())

    def _push_undo(self):
        """Save a snapshot of the current mask so that on_undo can restore it."""
        # A copy is required: the mask is later modified in place (item assignment,
        # ^=), which would otherwise also change the snapshot on the stack.
        self._undo_stack.append(self.mask.copy())

    def _apply_selection(self, tobemasked):
        """Mask, unmask or invert the pixels flagged in *tobemasked*, per the active mode button."""
        self._push_undo()
        if self.builder.get_object('mask_button').get_active():
            self.mask = self.mask & (~tobemasked)
        elif self.builder.get_object('unmask_button').get_active():
            self.mask = self.mask | tobemasked
        elif self.builder.get_object('invertmask_button').get_active():
            self.mask[tobemasked] = ~self.mask[tobemasked]

    def on_selectcircle_toggled(self, button):
        if button.get_active():
            self.set_sensitive(False, 'Ellipse selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectrectangle_button', 'selectpolygon_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            while self.plot2d.toolbar.mode != '':
                # turn off zoom, pan, etc. modes.
                self.plot2d.toolbar.zoom()
            self._selector = EllipseSelector(self.plot2d.axis,
                                             self.on_ellipse_selected,
                                             rectprops={'facecolor': 'white', 'edgecolor': 'none', 'alpha': 0.7,
                                                        'fill': True, 'zorder': 10},
                                             button=[1, ],
                                             interactive=False, lineprops={'zorder': 10})
            self._selector.state.add('square')
            self._selector.state.add('center')
        else:
            assert isinstance(self._selector, EllipseSelector)
            self._selector.set_active(False)
            self._selector.set_visible(False)
            self._selector = None
            self.plot2d.replot(keepzoom=False)
            self.set_sensitive(True)

    def on_ellipse_selected(self, pos1, pos2):
        # pos1 and pos2 are mouse button press and release events, with xdata and ydata carrying
        # the two opposite corners of the bounding box of the circle. These are NOT the exact
        # button presses and releases!
        row = np.arange(self.mask.shape[0])[:, np.newaxis]
        column = np.arange(self.mask.shape[1])[np.newaxis, :]
        row0 = 0.5 * (pos1.ydata + pos2.ydata)
        col0 = 0.5 * (pos1.xdata + pos2.xdata)
        # The bounding box is square, so diagonal**2 / 8 == (side / 2)**2 == radius**2.
        r2 = ((pos2.xdata - pos1.xdata) ** 2 + (pos2.ydata - pos1.ydata) ** 2) / 8
        tobemasked = (row - row0) ** 2 + (column - col0) ** 2 <= r2
        self._apply_selection(tobemasked)
        self.builder.get_object('selectcircle_button').set_active(False)
        self.plot2d.set_mask(self.mask)

    def on_selectrectangle_toggled(self, button):
        if button.get_active():
            self.set_sensitive(False, 'Rectangle selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectcircle_button', 'selectpolygon_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            while self.plot2d.toolbar.mode != '':
                # turn off zoom, pan, etc. modes.
                self.plot2d.toolbar.zoom()
            self._selector = RectangleSelector(self.plot2d.axis,
                                               self.on_rectangle_selected,
                                               rectprops={'facecolor': 'white', 'edgecolor': 'none', 'alpha': 0.7,
                                                          'fill': True, 'zorder': 10},
                                               button=[1, ],
                                               interactive=False, lineprops={'zorder': 10})
        else:
            self._selector.set_active(False)
            self._selector.set_visible(False)
            self._selector = None
            self.plot2d.replot(keepzoom=False)
            self.set_sensitive(True)

    def on_rectangle_selected(self, pos1, pos2):
        # pos1 and pos2 are mouse button press and release events, with xdata and ydata
        # carrying the two opposite corners of the bounding box of the rectangle. These
        # are NOT the exact button presses and releases!
        row = np.arange(self.mask.shape[0])[:, np.newaxis]
        column = np.arange(self.mask.shape[1])[np.newaxis, :]
        tobemasked = ((row >= min(pos1.ydata, pos2.ydata)) & (row <= max(pos1.ydata, pos2.ydata)) &
                      (column >= min(pos1.xdata, pos2.xdata)) & (column <= max(pos1.xdata, pos2.xdata)))
        self._apply_selection(tobemasked)
        self.builder.get_object('selectrectangle_button').set_active(False)
        self.plot2d.set_mask(self.mask)

    def on_selectpolygon_toggled(self, button):
        if button.get_active():
            self.set_sensitive(False, 'Polygon selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectrectangle_button', 'selectcircle_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            while self.plot2d.toolbar.mode != '':
                # turn off zoom, pan, etc. modes.
                self.plot2d.toolbar.zoom()
            self._selector = LassoSelector(self.plot2d.axis,
                                           self.on_polygon_selected,
                                           lineprops={'color': 'white', 'zorder': 10},
                                           button=[1, ],
                                           )
        else:
            self._selector.set_active(False)
            self._selector.set_visible(False)
            self._selector = None
            self.plot2d.replot(keepzoom=False)
            self.set_sensitive(True)

    def on_polygon_selected(self, vertices):
        path = Path(vertices)
        # (x, y) = (column, row) of every pixel centre, row-major.
        col, row = np.meshgrid(np.arange(self.mask.shape[1]),
                               np.arange(self.mask.shape[0]))
        points = np.vstack((col.flatten(), row.flatten())).T
        tobemasked = path.contains_points(points).reshape(self.mask.shape)
        self._apply_selection(tobemasked)
        self.plot2d.set_mask(self.mask)
        self.builder.get_object('selectpolygon_button').set_active(False)

    def on_mask_toggled(self, button):
        pass

    def on_unmask_toggled(self, button):
        pass

    def on_invertmask_toggled(self, button):
        pass

    def on_pixelhunting_toggled(self, button):
        """Arm (button active) or disarm a crosshair cursor that toggles single pixels on click."""
        if button.get_active():
            self._cursor = Cursor(self.plot2d.axis, useblit=False, color='white', lw=1)
            self._cursor.connect_event('button_press_event', self.on_cursorclick)
            while self.plot2d.toolbar.mode != '':
                # turn off zoom, pan, etc. modes.
                self.plot2d.toolbar.zoom()
        else:
            # The undo snapshot is taken per click in on_cursorclick, before the edit.
            if self._cursor is not None:
                self._cursor.disconnect_events()
                self._cursor = None
            self.plot2d.replot(keepzoom=False)

    def on_cursorclick(self, event):
        """Toggle the pixel under the cursor, then re-arm the cursor for the next click."""
        if (event.inaxes == self.plot2d.axis) and (self.plot2d.toolbar.mode == ''):
            row, col = round(event.ydata), round(event.xdata)
            # Clicks on the outer half-pixel can round to one past the last index.
            if 0 <= row < self.mask.shape[0] and 0 <= col < self.mask.shape[1]:
                self._push_undo()
                self.mask[row, col] ^= True
            self._cursor.disconnect_events()
            self._cursor = None
            self.plot2d.replot(keepzoom=True)
            self.on_pixelhunting_toggled(self.builder.get_object('pixelhunting_button'))

    def cleanup(self):
        super().cleanup()
        self._undo_stack = []

    def on_undo(self, button):
        try:
            self.mask = self._undo_stack.pop()
        except IndexError:
            return
        self.plot2d.set_mask(self.mask)
コード例 #4
0
class RoiSelector():
    """
    Region-of-interest selector (rectangle, ellipse or free-hand lasso) on a matplotlib axes.

    patch: matplotlib.patches.Patch
            Patch object that holds ROI information. Used to visually show ROI on image and collect
            data from ROI for analysis.
    """

    def __init__(self, axes, roi_type):
        """
        axes : matplotlib axes to draw on.
        roi_type : 'rectangle', 'ellipse', or anything else for a lasso.
        """
        self.axes = axes
        self.roi_type = roi_type
        self.patch = None
        self.lasso_switch = False  # True once a lasso has been drawn; label() then uses verts[0]
        self.verts = None
        self.title = 'N/A'
        self.annotate = None
        self.intensity = 0
        self.global_switch = False  # lasso selections are discarded while False
        if self.roi_type == 'rectangle':
            self.roi = RectangleSelector(self.axes, self.onselect, drawtype='box', interactive=True)
        elif self.roi_type == 'ellipse':
            self.roi = EllipseSelector(self.axes, self.onselect, drawtype='box', interactive=True)
        else:
            self.roi = LassoSelector(self.axes, onselect=self.lasso_select)

    def lasso_select(self, verts):
        """LassoSelector callback: draw the lasso as a translucent red patch and label it."""
        if self.global_switch is False:
            # Selection disabled: drop this selector and start a fresh one.
            del self.roi
            self.roi = LassoSelector(self.axes, onselect=self.lasso_select)
            self.axes.figure.canvas.draw()
            return
        if self.patch is not None:
            self.patch.remove()
        self.verts = verts
        self.p = path.Path(verts, closed=True)
        self.patch = patches.PathPatch(self.p, facecolor=(1, 0, 0, 0.3), lw=2)
        self.axes.add_patch(self.patch)
        self.lasso_switch = True
        if self.annotate is not None:
            self.annotate.remove()
            self.annotate = None
        self.label(self.title)

    def onselect(self, eclick, erelease):
        """Rectangle/ellipse selector callback: refresh the label at the new ROI centre."""
        if erelease:
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.label(self.title)

    def label(self, label):
        """Annotate the ROI with *label* (at the first lasso vertex, or the selector centre)."""
        if self.lasso_switch is True:
            if self.verts is None:
                return
            xy = self.verts[0]
        else:
            xy = self.roi.center
        self.annotate = self.axes.annotate(label, xy=xy,
                                           bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10}, fontsize=10)
        self.axes.figure.canvas.draw()

    def visible(self, switch):
        """Show (True) or hide (False) the selector and its label."""
        if switch is True:
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.roi.set_visible(True)
            self.label(self.title)
        if switch is False:
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.roi.set_visible(False)
        self.axes.figure.canvas.draw()

    def active(self, switch):
        self.roi.set_active(switch)

    def draw(self, extents):
        """Place the selector at *extents* (xmin, xmax, ymin, ymax) and show it."""
        self.roi.extents = extents
        self.roi.set_visible(True)

    def hide_lasso(self):
        pass

    def lasso_visible(self, switch):
        """Show or hide the lasso patch together with its label."""
        if self.patch is None:
            return
        self.patch.set_visible(switch)
        if switch is False:
            if self.annotate is not None:
                if self.annotate.get_visible():
                    self.annotate.remove()
                    self.annotate = None
            self.axes.figure.canvas.draw()
        else:
            self.label(self.title)

    def remove_lasso(self):
        """Remove the lasso patch from the axes (no-op if there is none)."""
        if self.patch is None:
            return
        self.patch.remove()
        # Forget the removed artist so a later lasso_select does not remove it twice.
        self.patch = None

    def sum_roi(self, data):
        """
        Sum *data* over the pixels inside the ROI and store the result in ``self.intensity``.

        *data* is indexed ``data[row, col]``, i.e. ``data[y, x]`` in axes data
        coordinates (the ``imshow`` convention), with pixel centres at integer
        coordinates. For data with trailing (e.g. colour) axes the sum is
        per channel. Returns 0 for a lasso ROI that has not been drawn yet.
        """
        import numpy as np

        data = np.asarray(data)
        if self.roi_type == 'rectangle':
            x0, x1, y0, y1 = (int(round(v)) for v in self.roi.extents)
            # Half-open [x0, x1) x [y0, y1); clip at 0 so negative extents do not wrap around.
            region = data[max(y0, 0):max(y1, 0), max(x0, 0):max(x1, 0)]
            intensity = region.sum(axis=(0, 1))
        else:
            ys, xs = np.indices(data.shape[:2])
            if self.roi_type == 'ellipse':
                cx, cy = self.roi.center
                xmin, xmax, ymin, ymax = self.roi.extents
                half_w = (xmax - xmin) / 2
                half_h = (ymax - ymin) / 2
                if half_w <= 0 or half_h <= 0:
                    inside = np.zeros(data.shape[:2], dtype=bool)
                else:
                    inside = ((xs - cx) / half_w) ** 2 + ((ys - cy) / half_h) ** 2 <= 1
            elif self.patch is None:
                inside = np.zeros(data.shape[:2], dtype=bool)
            else:
                # The lasso PathPatch was built from vertices in data coordinates.
                points = np.column_stack((xs.ravel(), ys.ravel()))
                inside = self.patch.get_path().contains_points(points).reshape(data.shape[:2])
            intensity = data[inside].sum(axis=0)
        self.intensity = intensity
        return self.intensity
コード例 #5
0
class image_segmenter:
    """
    Manually segment an image with the lasso selector.
    """

    def __init__(
        self,
        img,
        nclasses=1,
        mask=None,
        mask_colors=None,
        mask_alpha=0.75,
        lineprops=None,
        lasso_mousebutton="left",
        pan_mousebutton="middle",
        ax=None,
        figsize=(10, 10),
        **kwargs,
    ):
        """
        Create an image segmenter. Any ``kwargs`` will be passed through to the ``imshow``
        call that displays *img*. See :doc:`/examples/image-segmentation`.

        Parameters
        ----------
        img : array_like
            A valid argument to imshow
        nclasses : int, default 1
            Number of paintable classes (labelled ``1..nclasses``). Class ``0`` means
            "unlabelled" and is always drawn fully transparent.
        mask : arraylike, optional
            If you want to pre-seed the mask. It is stored as-is (not copied), so lasso
            selections modify it in place.
        mask_colors : None, color, or array of colors, optional
            the colors to use for each class. Unselected regions will always be totally transparent
        mask_alpha : float, default .75
            The alpha values to use for selected regions. This will always override the alpha values
            in mask_colors if any were passed
        lineprops : dict, default: None
            Line properties for the LassoSelector (passed as ``props`` on Matplotlib
            versions that renamed the argument). If None the default values are:
            {"color": "black", "linewidth": 1, "alpha": 0.8}
        lasso_mousebutton : str, or int, default: "left"
            The mouse button to use for drawing the selecting lasso.
        pan_mousebutton : str, or int, default: "middle"
            The button to use for `~mpl_interactions.generic.panhandler`. One of 'left', 'middle' or
            'right', or 1, 2, 3 respectively.
        ax : `matplotlib.axes.Axes`, optional
            The axis on which to plot. If *None* a new figure will be created.
        figsize : (float, float), optional
            passed to plt.figure. Ignored if *ax* is given.
        **kwargs
            All other kwargs will passed to the imshow command for the image
        """
        import inspect

        self.mask_alpha = mask_alpha

        # One RGBA row per class 1..nclasses (class 0 is never colored).
        if mask_colors is None:
            if nclasses <= 10:
                # TABLEAU_COLORS provides 10 distinct colors.
                self.mask_colors = to_rgba_array(list(TABLEAU_COLORS)[:nclasses])
            else:
                # up to 949 classes. Hopefully that is always enough....
                self.mask_colors = to_rgba_array(list(XKCD_COLORS)[:nclasses])
        else:
            self.mask_colors = to_rgba_array(np.atleast_1d(mask_colors))
            # NOTE(review): the number of colors is not checked against nclasses;
            # too few colors surface as an IndexError in the overlay loop below.
        # The requested alpha always overrides any alpha given in the colors.
        self.mask_colors[:, -1] = self.mask_alpha

        self._img = np.asarray(img)

        if mask is None:
            self.mask = np.zeros(self._img.shape[:2])
        else:
            self.mask = mask

        # RGBA overlay drawn on top of the image: transparent for class 0, the
        # class color for every labelled pixel.
        self._overlay = np.zeros((*self._img.shape[:2], 4))
        self.nclasses = nclasses
        for i in range(nclasses + 1):
            idx = self.mask == i
            if i == 0:
                self._overlay[idx] = [0, 0, 0, 0]
            else:
                self._overlay[idx] = self.mask_colors[i - 1]
        if ax is not None:
            self.ax = ax
            self.fig = self.ax.figure
        else:
            # Figure is created with interactive mode off; it is shown explicitly
            # through _ipython_display_.
            with ioff():
                self.fig = figure(figsize=figsize)
                self.ax = self.fig.gca()
        self.displayed = self.ax.imshow(self._img, **kwargs)
        self._mask = self.ax.imshow(self._overlay)

        if lineprops is None:
            lineprops = {"color": "black", "linewidth": 1, "alpha": 0.8}
        # Blitting is turned off for ipympl backends (presumably unreliable there).
        useblit = "ipympl" not in get_backend().lower()
        button_dict = {"left": 1, "middle": 2, "right": 3}
        if isinstance(pan_mousebutton, str):
            pan_mousebutton = button_dict[pan_mousebutton.lower()]
        if isinstance(lasso_mousebutton, str):
            lasso_mousebutton = button_dict[lasso_mousebutton.lower()]

        # Matplotlib 3.5 renamed LassoSelector's ``lineprops`` to ``props`` and 3.7
        # removed the old name, so pass the line style under whichever name exists.
        if "props" in inspect.signature(LassoSelector).parameters:
            line_kwargs = {"props": lineprops}
        else:
            line_kwargs = {"lineprops": lineprops}
        self.lasso = LassoSelector(
            self.ax, self._onselect, useblit=useblit, button=lasso_mousebutton, **line_kwargs
        )
        self.lasso.set_visible(True)

        # (column, row) coordinates of every pixel in row-major order, so a boolean
        # result computed over them reshapes directly to ``self.mask.shape``.
        pix_x = np.arange(self._img.shape[0])
        pix_y = np.arange(self._img.shape[1])
        xv, yv = np.meshgrid(pix_y, pix_x)
        self.pix = np.vstack((xv.flatten(), yv.flatten())).T

        self.ph = panhandler(self.fig, button=pan_mousebutton)
        self.disconnect_zoom = zoom_factory(self.ax)
        self.current_class = 1
        self.erasing = False

    def _onselect(self, verts):
        """LassoSelector callback: label (or erase) every pixel enclosed by *verts*.

        Pixels inside the lasso polygon are set to ``self.current_class`` in
        ``self.mask`` and painted with that class color in the overlay. When
        ``self.erasing`` is true they are reset to ``0`` and made transparent
        instead. The overlay image is then refreshed.
        """
        self.verts = verts
        lasso_path = Path(verts)
        # self.pix lists every pixel coordinate, so the flat hit-test reshapes to the mask.
        hits = lasso_path.contains_points(self.pix, radius=0)
        self.indices = hits.reshape(self.mask.shape)
        selected = self.indices

        if not self.erasing:
            self.mask[selected] = self.current_class
            fill = self.mask_colors[self.current_class - 1]
        else:
            self.mask[selected] = 0
            fill = [0, 0, 0, 0]
        self._overlay[selected] = fill

        self._mask.set_data(self._overlay)
        self.fig.canvas.draw_idle()

    def _ipython_display_(self):
        """IPython rich-display hook: display the figure's canvas."""
        canvas = self.fig.canvas
        display(canvas)