class RoiSelector():
    """Interactive region of interest (ROI) drawn on a matplotlib Axes.

    ``roi_type`` selects the drawing tool: ``'rectangle'`` uses a
    RectangleSelector, ``'ellipse'`` an EllipseSelector, and any other value a
    LassoSelector whose outline is kept as a semi-transparent PathPatch.

    Attributes
    ----------
    patch : matplotlib.patches.Patch or None
        Patch object that holds the lasso ROI (only set for lasso ROIs). Used
        to visually show the ROI on the image and collect data from it for
        analysis.
    p : matplotlib.path.Path or None
        Closed path of the last lasso selection, in data coordinates.
    title : str
        Text of the annotation drawn next to the ROI.
    intensity
        Result of the most recent ``sum_roi`` call.
    """

    def __init__(self, axes, roi_type):
        self.axes = axes
        self.roi_type = roi_type
        self.patch = None
        self.p = None
        self.lasso_switch = False  # True once a lasso outline has been drawn
        self.verts = None
        self.title = 'N/A'
        self.annotate = None
        self.intensity = 0
        # While False, lasso selections are discarded (see lasso_select).
        self.global_switch = False
        if self.roi_type == 'rectangle':
            self.roi = RectangleSelector(self.axes, self.onselect, drawtype='box', interactive=True)
        elif self.roi_type == 'ellipse':
            self.roi = EllipseSelector(self.axes, self.onselect, drawtype='box', interactive=True)
        else:
            self.roi = LassoSelector(self.axes, onselect=self.lasso_select)

    def lasso_select(self, verts):
        """LassoSelector callback: replace the lasso outline with one built from ``verts``.

        While ``global_switch`` is False the selection is discarded and a fresh
        LassoSelector is installed instead.
        """
        if self.global_switch is False:
            del self.roi
            self.roi = LassoSelector(self.axes, onselect=self.lasso_select)
            self.axes.figure.canvas.draw()
            return
        if self.patch is not None:
            # Drop the previous outline (remove_lasso() may already have done so).
            self.patch.remove()
        self.verts = verts
        self.p = path.Path(verts, closed=True)
        self.patch = patches.PathPatch(self.p, facecolor=(1, 0, 0, 0.3), lw=2)
        self.axes.add_patch(self.patch)
        self.lasso_switch = True
        if self.annotate is not None:
            self.annotate.remove()
        self.label(self.title)

    def onselect(self, eclick, erelease):
        """Rectangle/ellipse selector callback: move the label to the new ROI."""
        if erelease:
            if self.annotate is not None:
                self.annotate.remove()
            self.label(self.title)

    def label(self, label):
        """Draw ``label`` as an annotation anchored on the ROI and redraw the canvas.

        Lasso ROIs are anchored at their first vertex (nothing is drawn if no
        lasso has been made yet); other ROIs at the selector's centre.
        """
        if self.lasso_switch is True:
            if self.verts is None:
                return
            xy = self.verts[0]
        else:
            xy = self.roi.center
        self.annotate = self.axes.annotate(label, xy=xy,
                                           bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10},
                                           fontsize=10)
        self.axes.figure.canvas.draw()

    def visible(self, switch):
        """Show (``True``) or hide (``False``) the selector and its label."""
        if switch is True:
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.roi.set_visible(True)
            self.label(self.title)
        if switch is False:
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.roi.set_visible(False)
            self.axes.figure.canvas.draw()

    def active(self, switch):
        """Enable or disable interaction with the selector."""
        self.roi.set_active(switch)

    def draw(self, extents):
        """Set the selector's extents programmatically and make it visible."""
        self.roi.extents = extents
        self.roi.set_visible(True)

    def hide_lasso(self):
        pass

    def lasso_visible(self, switch):
        """Show or hide the lasso outline and its label (no-op without an outline)."""
        if self.patch is None:
            return
        self.patch.set_visible(switch)
        if switch is False:
            if self.annotate is not None:
                if self.annotate.get_visible():
                    self.annotate.remove()
                self.annotate = None
            self.axes.figure.canvas.draw()
        else:
            # Avoid stacking a second label on top of an existing one.
            if self.annotate is not None:
                self.annotate.remove()
                self.annotate = None
            self.label(self.title)

    def remove_lasso(self):
        """Remove the lasso outline from the axes; safe to call repeatedly."""
        if self.patch is None:
            return
        self.patch.remove()
        self.patch = None

    def _roi_path(self):
        """Return the non-rectangular ROI outline as a Path in data coordinates, or None."""
        if self.roi_type == 'ellipse':
            xmin, xmax, ymin, ymax = self.roi.extents
            ellipse = patches.Ellipse(self.roi.center, xmax - xmin, ymax - ymin)
            # Ellipse.get_path() is the unit circle; the patch transform maps
            # it onto the ellipse in data coordinates.
            return ellipse.get_patch_transform().transform_path(ellipse.get_path())
        # Lasso: the path stored by lasso_select is already in data coordinates.
        return self.p

    def sum_roi(self, data):
        """Sum the values of ``data`` that lie inside the ROI.

        ``data`` is indexed ``data[row, column]`` (``data[y, x]``), the layout
        ``imshow`` uses to place an array on the axes. Extra trailing
        dimensions (e.g. colour channels) are summed per channel.

        Returns the sum and also stores it in ``self.intensity``. A lasso ROI
        that has never been drawn sums to 0.
        """
        import numpy as np

        data = np.asarray(data)
        if self.roi_type == 'rectangle':
            xmin, xmax, ymin, ymax = (int(round(v)) for v in self.roi.extents)
            # Clamp at 0: negative slice bounds would wrap around the array.
            window = data[max(ymin, 0):max(ymax, 0), max(xmin, 0):max(xmax, 0)]
            intensity = window.sum(axis=(0, 1))
        else:
            region = self._roi_path()
            if region is None:
                intensity = 0
            else:
                rows, cols = np.indices(data.shape[:2])
                # Path.contains_points expects (x, y) == (column, row) pairs.
                points = np.column_stack((cols.ravel(), rows.ravel()))
                inside = region.contains_points(points).reshape(data.shape[:2])
                intensity = data[inside].sum(axis=0)
        self.intensity = intensity
        return self.intensity
class image_segmenter:
    """Interactive lasso / flood-fill labelling widget for a folder of images.

    Each pixel of ``class_mask`` holds 0 for background or ``1 + index`` of the
    class chosen in the dropdown. Masks are saved under ``mask_dir`` with the
    same file name as the image they belong to.
    """

    def __init__(self, img_dir, classes, overlay_alpha=.5, figsize=(10, 10), scroll_to_zoom=True,
                 zoom_scale=1.1):
        """
        TODO allow for intializing with a shape instead of an image

        parameters
        ----------
        img_dir : string
            path to a directory that contains a 'train/' folder; the images to
            label are read from 'train/'.
        classes : Int or list
            Number of classes or a list of class names (at most 20).
        overlay_alpha : float
            Opacity of the class colour drawn over labelled pixels.
        figsize : tuple
            Size of the matplotlib figure.
        scroll_to_zoom : boolean
            Currently unused; zooming is controlled by ``zoom_scale``.
        zoom_scale : float or None
            How much to scale the image per scroll. If you do this I recommend
            using jupyterlab-sidecar in order to prevent the page from scrolling,
            or checking in on: https://github.com/matplotlib/ipympl/issues/222
            To disable zoom set this to None.

        raises
        ------
        ValueError
            If 'train/' is missing or contains no images, or if more than 20
            classes are requested.
        """
        self.img_dir = img_dir
        if not path.isdir(path.join(self.img_dir, 'train')):
            raise ValueError(
                f"{self.img_dir} must exist and contain the the folder 'train'"
            )
        self.img_dir = path.join(self.img_dir, 'train')

        # Masks go to '<prefix>/train_masks/train', where <prefix> is the part of
        # the image directory before 'train_imgs/' (the whole image directory
        # when it does not contain 'train_imgs/').
        self.mask_dir = path.join(
            self.img_dir.rsplit('train_imgs/', 1)[0], 'train_masks/train')
        os.makedirs(self.mask_dir, exist_ok=True)

        self.image_paths = []
        for type_ in VALID_IMAGE_TYPES:
            self.image_paths += glob.glob(self.img_dir.rstrip('/') + f'/*.{type_}')
        if not self.image_paths:
            raise ValueError(f'no images found in {self.img_dir}')
        self.shape = None

        plt.ioff()  # see https://github.com/matplotlib/matplotlib/issues/17013
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.gca()
        # setup lasso stuff
        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        self.lasso = LassoSelector(self.ax, self.onselect, lineprops=lineprops, button=1,
                                   useblit=False)
        self.lasso.set_visible(True)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self._release)
        self.panhandler = panhandler(self.fig)
        plt.ion()

        if isinstance(classes, int):
            classes = np.arange(classes)
        if len(classes) <= 10:
            cmap_name = 'tab10'
        elif len(classes) <= 20:
            cmap_name = 'tab20'
        else:
            raise ValueError(
                f'Currently only up to 20 classes are supported, you tried to use {len(classes)} classes'
            )
        # Row 0 (black) is the background; row i + 1 is the colour of class i.
        self.colors = np.vstack([[0, 0, 0],
                                 plt.get_cmap(cmap_name)(np.arange(len(classes)))[:, :3]])

        self.class_dropdown = widgets.Dropdown(
            options=[(str(classes[i]), i) for i in range(len(classes))],
            value=0,
            description='Class:',
            disabled=False,
        )
        self.lasso_button = widgets.Button(
            description='lasso select',
            disabled=False,
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            icon='mouse-pointer',  # (FontAwesome names without the `fa-` prefix)
        )
        self.flood_button = widgets.Button(
            description='flood fill',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            icon='fill-drip',  # (FontAwesome names without the `fa-` prefix)
        )
        self.erase_check_box = widgets.Checkbox(value=False, description='Erase Mode',
                                                disabled=False, indent=False)
        self.reset_button = widgets.Button(
            description='reset',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            icon='refresh',  # (FontAwesome names without the `fa-` prefix)
        )
        self.save_button = widgets.Button(description='save mask', button_style='',
                                          icon='floppy-o')
        # With a single image there is nowhere to go.
        self.next_button = widgets.Button(description='next image', button_style='',
                                          icon='arrow-right',
                                          disabled=len(self.image_paths) <= 1)
        self.prev_button = widgets.Button(description='previous image', button_style='',
                                          icon='arrow-left', disabled=True)
        self.reset_button.on_click(self.reset)
        self.save_button.on_click(self.save_mask)
        self.next_button.on_click(self._change_image_idx)
        self.prev_button.on_click(self._change_image_idx)

        def button_click(button):
            # Switch between flood-fill and lasso mode; the lasso only reacts
            # to the mouse while it is active.
            if button.description == 'flood fill':
                self.flood_button.button_style = 'success'
                self.lasso_button.button_style = ''
                self.lasso.set_active(False)
            else:
                self.flood_button.button_style = ''
                self.lasso_button.button_style = 'success'
                self.lasso.set_active(True)

        self.lasso_button.on_click(button_click)
        self.flood_button.on_click(button_click)
        self.overlay_alpha = overlay_alpha
        self.indices = None
        self.new_image(0)

        # gotta do this after creating the image, and the image needs to come after all the buttons
        if zoom_scale is not None:
            self.disconnect_scroll = zoom_factory(self.ax, base_scale=zoom_scale)

    def _change_image_idx(self, button):
        """Next/previous button callback: save the current mask and load the neighbour image."""
        if button is self.next_button:
            if self.img_idx + 1 < len(self.image_paths):
                self.img_idx += 1
                # mask_path/class_mask still belong to the previous image here.
                self.save_mask()
                self.new_image(self.img_idx)
                if self.img_idx == len(self.image_paths) - 1:
                    self.next_button.disabled = True
                self.prev_button.disabled = False
        elif button is self.prev_button:
            if self.img_idx >= 1:
                self.img_idx -= 1
                self.save_mask()
                self.new_image(self.img_idx)
                if self.img_idx == 0:
                    self.prev_button.disabled = True
                self.next_button.disabled = False

    def new_image(self, img_idx):
        """Display image ``img_idx`` and load its saved mask (or start from an empty one)."""
        self.indices = None
        self.img = io.imread(self.image_paths[img_idx])
        self.img_idx = img_idx
        img_path = self.image_paths[self.img_idx]
        self.ax.set_title(os.path.basename(img_path))
        self.mask_path = self.mask_dir + f'/{os.path.basename(img_path)}'

        if self.img.shape != self.shape:
            self.shape = self.img.shape
            # (x, y) of every pixel centre in row-major order; onselect() tests
            # these against the lasso path.
            pix_x = np.arange(self.shape[0])
            pix_y = np.arange(self.shape[1])
            xv, yv = np.meshgrid(pix_y, pix_x)
            self.pix = np.vstack((xv.flatten(), yv.flatten())).T
            previous = getattr(self, 'displayed', None)
            if previous is not None:
                # Replace, rather than stack on top of, the previous image.
                previous.remove()
            self.displayed = self.ax.imshow(self.img)
            # ensure that the _nav_stack is empty
            self.fig.canvas.toolbar._nav_stack.clear()
            # add the initial view to the stack so that the home button works.
            self.fig.canvas.toolbar.push_current()
            if os.path.exists(self.mask_path):
                self.class_mask = io.imread(self.mask_path)
            else:
                self.class_mask = np.zeros([self.shape[0], self.shape[1]], dtype=np.uint8)
        else:
            self.displayed.set_data(self.img)
            if os.path.exists(self.mask_path):
                # NOTE(review): the mask's first two dimensions are not checked
                # against the image.
                self.class_mask = io.imread(self.mask_path)
            else:
                self.class_mask[:, :] = 0
            self.fig.canvas.toolbar.home()
        self.updateArray()

    def _release(self, event):
        """Forward mouse releases to the pan handler."""
        self.panhandler.release(event)

    def reset(self, *args):
        """Clear every label (back to background class 0) and show the bare image."""
        self.displayed.set_data(self.img)
        # 0 is the background class; class_mask is uint8, so -1 is not representable.
        self.class_mask[:, :] = 0
        self.fig.canvas.draw()

    def onclick(self, event):
        """
        handle clicking to remove already added stuff

        Left click (with the lasso inactive) flood-fills the clicked region of
        ``class_mask``; right click starts panning.
        """
        if event.button == 1:
            if event.xdata is not None and not self.lasso.active:
                # imshow puts pixel (row, col) at (x=col, y=row) with pixel
                # centres on integer coordinates: round, then clamp to the image.
                row = min(max(int(round(event.ydata)), 0), self.class_mask.shape[0] - 1)
                col = min(max(int(round(event.xdata)), 0), self.class_mask.shape[1] - 1)
                self.indices = flood(self.class_mask, (row, col))
                self.updateArray()
                self.fig.canvas.draw_idle()
        elif event.button == 3:
            self.panhandler.press(event)

    def updateArray(self):
        """Apply ``self.indices`` to ``class_mask`` and refresh the displayed overlay.

        * ``indices is None`` (freshly loaded image): overlay every pixel whose
          class is non-zero, regardless of erase mode.
        * erase mode: reset the selected pixels to class 0 and restore the
          original image there.
        * otherwise: assign the dropdown's class to the selected pixels.
        """
        array = self.displayed.get_array().data
        if self.indices is None:
            idx = self.class_mask != 0
        elif self.erase_check_box.value:
            self.class_mask[self.indices] = 0
            array[self.indices] = self.img[self.indices]
            self.displayed.set_data(array)
            return
        else:
            self.class_mask[self.indices] = self.class_dropdown.value + 1
            idx = self.indices
        # https://en.wikipedia.org/wiki/Alpha_compositing#Straight_versus_premultiplied
        c_overlay = self.colors[self.class_mask[idx]] * 255 * self.overlay_alpha
        array[idx] = c_overlay + self.img[idx] * (1 - self.overlay_alpha)
        self.displayed.set_data(array)

    def onselect(self, verts):
        """Lasso callback: select the pixels whose centres lie inside the lasso."""
        self.verts = verts
        p = Path(verts)
        self.indices = p.contains_points(self.pix, radius=0).reshape(self.shape[0],
                                                                    self.shape[1])
        self.updateArray()
        self.fig.canvas.draw_idle()

    def render(self):
        """Build the widget layout: mode buttons, class controls, canvas, navigation."""
        layers = [widgets.HBox([self.lasso_button, self.flood_button])]
        layers.append(
            widgets.HBox([self.reset_button, self.class_dropdown, self.erase_check_box]))
        layers.append(self.fig.canvas)
        layers.append(
            widgets.HBox([self.save_button, self.prev_button, self.next_button]))
        return widgets.VBox(layers)

    def save_mask(self, save_if_no_nonzero=False):
        """
        save_if_no_nonzero : boolean
            Whether to save if class_mask only contains 0s. When this method is
            used as the save button's callback the button object is passed
            here, which is truthy, so an explicit save always writes the file.
        """
        if save_if_no_nonzero or np.any(self.class_mask != 0):
            if os.path.splitext(self.mask_path)[1].lower() in ('.jpg', '.jpeg'):
                # NOTE: JPEG is lossy even at quality=100, which can alter labels.
                io.imsave(self.mask_path, self.class_mask, check_contrast=False, quality=100)
            else:
                io.imsave(self.mask_path, self.class_mask, check_contrast=False)

    def _ipython_display_(self):
        display(self.render())
class MaskEditor(ToolWindow, DoubleFileChooserDialog):
    """Tool window for editing a pixel mask drawn over a loaded exposure.

    Pixels are (un)masked with circle, rectangle, polygon (lasso) selections
    or one-by-one ("pixel hunting"). The mode toggle buttons decide how a
    selection is combined with the mask: 'mask' clears selected pixels,
    'unmask' sets them and 'invert' flips them. A new mask is all ones, i.e.
    set pixels are the ones kept. Masks are stored in ``.mat`` files.
    Every edit first pushes the previous mask onto ``_undo_stack``.
    """

    def __init__(self, *args, **kwargs):
        self.mask = None
        self._undo_stack = []
        self._im = None
        self._selector = None
        self._cursor = None
        self.exposureloader = None
        self.plot2d = None
        ToolWindow.__init__(self, *args, **kwargs)
        DoubleFileChooserDialog.__init__(
            self, self.widget,
            'Open mask file...', 'Save mask file...', [('Mask files', '*.mat'), ('All files', '*')],
            self.instrument.config['path']['directories']['mask'],
            os.path.abspath(self.instrument.config['path']['directories']['mask']),
        )

    def init_gui(self, *args, **kwargs):
        """Create the exposure loader and the image plot; the toolbar stays disabled until an exposure is loaded."""
        self.exposureloader = ExposureLoader(self.instrument)
        self.builder.get_object('loadexposure_expander').add(self.exposureloader)
        self.exposureloader.connect('open', self.on_loadexposure)
        self.plot2d = PlotImageWidget()
        self.builder.get_object('plotbox').pack_start(self.plot2d.widget, True, True, 0)
        self.builder.get_object('toolbar').set_sensitive(False)

    def on_loadexposure(self, exposureloader: ExposureLoader, im: Exposure):
        """Show the loaded exposure; its own mask is adopted only if no mask is being edited yet."""
        if self.mask is None:
            self.mask = im.mask
        self._im = im
        self.plot2d.set_image(im.intensity)
        self.plot2d.set_mask(self.mask)
        self.builder.get_object('toolbar').set_sensitive(True)

    def on_new(self, button):
        """Start a fresh mask of ones with the current shape."""
        if self._im is None or self.mask is None:
            return False
        self.mask = np.ones_like(self.mask)
        self.plot2d.set_mask(self.mask)
        self.set_last_filename(None)

    def on_open(self, button):
        """Load a mask from a .mat file: the first variable whose name does not start with '__'."""
        filename = self.get_open_filename()
        if filename is not None:
            mask = loadmat(filename)
            self.mask = mask[[k for k in mask.keys() if not k.startswith('__')][0]]
            self.plot2d.set_mask(self.mask)

    def on_save(self, button):
        """Save to the last used file name (asks for one if there is none)."""
        filename = self.get_last_filename()
        if filename is None:
            return self.on_saveas(button)
        # The .mat variable is named after the file's base name.
        maskname = os.path.splitext(os.path.split(filename)[1])[0]
        savemat(filename, {maskname: self.mask})

    def on_saveas(self, button):
        filename = self.get_save_filename(None)
        if filename is not None:
            self.on_save(button)

    def suggest_filename(self):
        return 'mask_dist_{0.year:d}{0.month:02d}{0.day:02d}.mat'.format(datetime.date.today())

    def _turn_off_toolbar_modes(self):
        """Leave any active zoom/pan mode of the plot toolbar."""
        while self.plot2d.toolbar.mode != '':
            self.plot2d.toolbar.zoom()

    def _deactivate_selector(self):
        """Tear down the current selection tool, redraw and re-enable the GUI."""
        if self._selector is not None:
            self._selector.set_active(False)
            self._selector.set_visible(False)
            self._selector = None
        self.plot2d.replot(keepzoom=False)
        self.set_sensitive(True)

    def _apply_selection(self, tobemasked):
        """Combine the boolean array ``tobemasked`` with the mask according to the active mode.

        The current mask is pushed onto the undo stack first. Every branch
        builds a new array, so the pushed mask is never modified afterwards and
        undo really restores it.
        """
        self._undo_stack.append(self.mask)
        if self.builder.get_object('mask_button').get_active():
            self.mask = self.mask & ~tobemasked
        elif self.builder.get_object('unmask_button').get_active():
            self.mask = self.mask | tobemasked
        elif self.builder.get_object('invertmask_button').get_active():
            inverted = self.mask.copy()
            inverted[tobemasked] = ~inverted[tobemasked]
            self.mask = inverted

    def on_selectcircle_toggled(self, button):
        """Start (button pressed) or cancel (button released) a circular selection."""
        if button.get_active():
            self.set_sensitive(False, 'Ellipse selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectrectangle_button', 'selectpolygon_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            self._turn_off_toolbar_modes()
            self._selector = EllipseSelector(self.plot2d.axis,
                                             self.on_ellipse_selected,
                                             rectprops={'facecolor': 'white', 'edgecolor': 'none', 'alpha': 0.7,
                                                        'fill': True, 'zorder': 10},
                                             button=[1, ],
                                             interactive=False, lineprops={'zorder': 10})
            # Square bounding box drawn from its centre -> a circle around the first click.
            self._selector.state.add('square')
            self._selector.state.add('center')
        else:
            self._deactivate_selector()

    def on_ellipse_selected(self, pos1, pos2):
        # pos1 and pos2 are mouse button press and release events, with xdata and ydata carrying
        # the two opposite corners of the bounding box of the circle. These are NOT the exact
        # button presses and releases!
        row = np.arange(self.mask.shape[0])[:, np.newaxis]
        column = np.arange(self.mask.shape[1])[np.newaxis, :]
        row0 = 0.5 * (pos1.ydata + pos2.ydata)
        col0 = 0.5 * (pos1.xdata + pos2.xdata)
        # The bounding box is a square of side 2r, so dx**2 + dy**2 == 8 * r**2.
        r2 = ((pos2.xdata - pos1.xdata) ** 2 + (pos2.ydata - pos1.ydata) ** 2) / 8
        tobemasked = (row - row0) ** 2 + (column - col0) ** 2 <= r2
        self._apply_selection(tobemasked)
        self.builder.get_object('selectcircle_button').set_active(False)
        self.plot2d.set_mask(self.mask)

    def on_selectrectangle_toggled(self, button):
        """Start (button pressed) or cancel (button released) a rectangular selection."""
        if button.get_active():
            self.set_sensitive(False, 'Rectangle selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectcircle_button', 'selectpolygon_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            self._turn_off_toolbar_modes()
            self._selector = RectangleSelector(self.plot2d.axis,
                                               self.on_rectangle_selected,
                                               rectprops={'facecolor': 'white', 'edgecolor': 'none', 'alpha': 0.7,
                                                          'fill': True, 'zorder': 10},
                                               button=[1, ],
                                               interactive=False, lineprops={'zorder': 10})
        else:
            self._deactivate_selector()

    def on_rectangle_selected(self, pos1, pos2):
        # pos1 and pos2 are mouse button press and release events, with xdata and ydata
        # carrying the two opposite corners of the bounding box of the rectangle. These
        # are NOT the exact button presses and releases!
        row = np.arange(self.mask.shape[0])[:, np.newaxis]
        column = np.arange(self.mask.shape[1])[np.newaxis, :]
        tobemasked = ((row >= min(pos1.ydata, pos2.ydata)) & (row <= max(pos1.ydata, pos2.ydata)) &
                      (column >= min(pos1.xdata, pos2.xdata)) & (column <= max(pos1.xdata, pos2.xdata)))
        self._apply_selection(tobemasked)
        self.builder.get_object('selectrectangle_button').set_active(False)
        self.plot2d.set_mask(self.mask)

    def on_selectpolygon_toggled(self, button):
        """Start (button pressed) or cancel (button released) a free-hand polygon selection."""
        if button.get_active():
            self.set_sensitive(False, 'Polygon selection not ready',
                               ['new_button', 'save_button', 'saveas_button', 'open_button', 'undo_button',
                                'selectrectangle_button', 'selectcircle_button', 'pixelhunting_button',
                                'loadexposure_expander', 'close_button', self.plot2d.toolbar,
                                self.plot2d.settings_expander])
            self._turn_off_toolbar_modes()
            self._selector = LassoSelector(self.plot2d.axis,
                                           self.on_polygon_selected,
                                           lineprops={'color': 'white', 'zorder': 10},
                                           button=[1, ],
                                           )
        else:
            self._deactivate_selector()

    def on_polygon_selected(self, vertices):
        """Apply the lasso polygon: pixels whose centres lie inside it are selected."""
        outline = Path(vertices)
        col, row = np.meshgrid(np.arange(self.mask.shape[1]),
                               np.arange(self.mask.shape[0]))
        # contains_points expects (x, y) == (column, row) pairs.
        points = np.vstack((col.flatten(), row.flatten())).T
        tobemasked = outline.contains_points(points).reshape(self.mask.shape)
        self._apply_selection(tobemasked)
        self.plot2d.set_mask(self.mask)
        self.builder.get_object('selectpolygon_button').set_active(False)

    def on_mask_toggled(self, button):
        pass

    def on_unmask_toggled(self, button):
        pass

    def on_invertmask_toggled(self, button):
        pass

    def on_pixelhunting_toggled(self, button):
        """Enable (button pressed) or disable pixel-by-pixel toggling with a crosshair cursor."""
        if button.get_active():
            self._cursor = Cursor(self.plot2d.axis, useblit=False, color='white', lw=1)
            self._cursor.connect_event('button_press_event', self.on_cursorclick)
            self._turn_off_toolbar_modes()
        else:
            if self._cursor is not None:
                self._cursor.disconnect_events()
                self._cursor = None
            # Undo snapshots are taken per click in on_cursorclick.
            self.plot2d.replot(keepzoom=False)

    def on_cursorclick(self, event):
        """Flip the clicked pixel, then re-arm the cursor for the next click."""
        if (event.inaxes == self.plot2d.axis) and (self.plot2d.toolbar.mode == ''):
            if self.mask is None:
                return
            row, col = int(round(event.ydata)), int(round(event.xdata))
            if not (0 <= row < self.mask.shape[0] and 0 <= col < self.mask.shape[1]):
                # Outside the image: negative indices would silently wrap around.
                return
            # Snapshot before editing: the mask is modified in place below (the
            # plot widget holds the same array), so a copy is needed for undo.
            self._undo_stack.append(self.mask.copy())
            self.mask[row, col] ^= True
            self._cursor.disconnect_events()
            self._cursor = None
            self.plot2d.replot(keepzoom=True)
            self.on_pixelhunting_toggled(self.builder.get_object('pixelhunting_button'))

    def cleanup(self):
        super().cleanup()
        self._undo_stack = []

    def on_undo(self, button):
        """Restore the mask as it was before the most recent edit (no-op if nothing to undo)."""
        try:
            self.mask = self._undo_stack.pop()
        except IndexError:
            return
        self.plot2d.set_mask(self.mask)