class LabelUI:
    """Tk UI for stepping through a directory of images and labelling points.

    Key bindings: RIGHT / LEFT move between images, UP toggles the display of
    labels, N jumps to the next image that has no labels in the label DB.
    Left mouse button adds a point, right mouse button removes the closest one.

    While labels are toggled off, all navigation and editing is ignored so
    the temporarily stashed points cannot end up on a different image.
    """

    def __init__(self, label_db_filename, img_dir, sort=True):
        """Build the UI and enter the Tk main loop (blocks until closed).

        Args:
            label_db_filename: path handed to LabelDB; also the window title.
            img_dir: directory whose entries are reviewed.
            sort: review files in sorted order if True, else shuffled.
        """
        # What images to review?
        # Note: drop trailing / in dir name (if present).
        self.img_dir = re.sub("/$", "", img_dir)
        self.files = os.listdir(img_dir)
        if sort:
            self.files = sorted(self.files)
        else:
            random.shuffle(self.files)
        print("%d files to review" % len(self.files))

        # Label db.
        self.label_db = LabelDB(label_db_filename)
        self.label_db.create_if_required()

        # TK UI.
        root = tk.Tk()
        root.title(label_db_filename)
        root.bind('<Right>', self.display_next_image)
        print("RIGHT  next image")
        root.bind('<Left>', self.display_previous_image)
        print("LEFT   previous image")
        root.bind('<Up>', self.toggle_bees)
        print("UP     toggle labels")
        root.bind('N', self.display_next_unlabelled_image)
        print("N   next image with 0 labels")
        self.canvas = tk.Canvas(root, cursor='tcross')
        self.canvas.config(width=768, height=1024)
        self.canvas.bind('<Button-1>', self.add_bee_event)  # left mouse button
        self.canvas.bind('<Button-3>',
                         self.remove_closest_bee_event)  # right mouse button
        self.canvas.pack()

        # A lookup table from bee x,y to any rectangles that have been drawn
        # in case we want to remove one. The keys of this dict represent all
        # the bee x,y in current image.
        self.x_y_to_boxes = {}  # { (x, y): canvas_id, ... }

        # Points stashed by toggle_bees while labels are hidden.
        self.tmp_x_y = []

        # A flag to denote if bees are being displayed or not.
        # While not displayed we lock down all img navigation.
        self.bees_on = True

        # Main review loop.
        self.file_idx = 0
        self.display_new_image()
        root.mainloop()

    def add_bee_event(self, e):
        """Mouse handler: add a point at the event position."""
        if not self.bees_on:
            print("ignore add bee; bees not on")
            return
        self.add_bee_at(e.x, e.y)

    def add_bee_at(self, x, y):
        """Draw a small red square centred on (x, y) and remember its id."""
        rectangle_id = self.canvas.create_rectangle(x - 2, y - 2,
                                                    x + 2, y + 2,
                                                    fill='red')
        self.x_y_to_boxes[(x, y)] = rectangle_id

    def remove_bee(self, rectangle_id):
        """Delete a previously drawn rectangle from the canvas."""
        self.canvas.delete(rectangle_id)

    def toggle_bees(self, e):
        """Key handler: hide all points (stashing them) or restore them."""
        if self.bees_on:
            # Stash the x,y pairs and delete all rectangles from the canvas.
            self.tmp_x_y = list(self.x_y_to_boxes)
            for rectangle_id in self.x_y_to_boxes.values():
                self.remove_bee(rectangle_id)
            self.x_y_to_boxes = {}
            self.bees_on = False
        else:
            # Restore all temporarily stashed points.
            for x, y in self.tmp_x_y:
                self.add_bee_at(x, y)
            self.bees_on = True

    def remove_closest_bee_event(self, e):
        """Mouse handler: remove the point closest to the event position."""
        if not self.bees_on:
            print("ignore remove bee; bees not on")
            return
        if not self.x_y_to_boxes:
            return
        # Squared distance is enough for ranking; ties resolve to the first
        # point in dict order.
        closest_point = min(
            self.x_y_to_boxes,
            key=lambda p: (e.x - p[0]) ** 2 + (e.y - p[1]) ** 2)
        self.remove_bee(self.x_y_to_boxes.pop(closest_point))

    def display_next_image(self, e=None):
        """Save pending points and show the next image (clamped at the end)."""
        if not self.bees_on:
            print("ignore move to next image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx += 1
        if self.file_idx == len(self.files):
            print("Can't move to image past last image.")
            self.file_idx = len(self.files) - 1
        self.display_new_image()

    def display_next_unlabelled_image(self, e=None):
        """Save pending points and jump to the next image without labels.

        Stops at the last image if every remaining image has labels.
        """
        # Same lock-down as the other navigation handlers: moving while the
        # points are stashed would later restore them onto the wrong image.
        if not self.bees_on:
            print("ignore move to next unlabelled image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        while True:
            self.file_idx += 1
            if self.file_idx == len(self.files):
                print("Can't move to image past last image.")
                self.file_idx = len(self.files) - 1
                break
            if not self.label_db.has_labels(self.files[self.file_idx]):
                break
        self.display_new_image()

    def display_previous_image(self, e=None):
        """Save pending points and show the previous image (clamped at 0)."""
        if not self.bees_on:
            print("ignore move to previous image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx -= 1
        if self.file_idx < 0:
            print("Can't move to image previous to first image.")
            self.file_idx = 0
        self.display_new_image()

    def _flush_pending_x_y_to_boxes(self):
        """Write the current image's points to the label DB and forget them."""
        img_name = self.files[self.file_idx]
        # NOTE(review): an empty point set is never written, so removing every
        # point from a previously labelled image is not persisted. Left as is
        # because the effect of set_labels(img, []) on has_labels() is not
        # visible from here -- confirm against LabelDB before changing.
        if len(self.x_y_to_boxes) > 0:
            self.label_db.set_labels(img_name, self.x_y_to_boxes.keys())
            self.x_y_to_boxes.clear()

    def display_new_image(self):
        """Draw the image at file_idx plus any points stored for it."""
        img_name = self.files[self.file_idx]
        # Display image (with filename added).
        img = Image.open(self.img_dir + "/" + img_name)
        canvas = ImageDraw.Draw(img)
        canvas.text((0, 0), img_name, fill='black')
        # Keep a reference on self; Tk does not hold one for us.
        self.tk_img = ImageTk.PhotoImage(img)
        # Drop the previous image and its rectangles first, otherwise every
        # navigation stacks more items on the canvas.
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_img, anchor=tk.NW)
        # Look up any existing bees in DB for this image.
        existing_labels = self.label_db.get_labels(img_name)
        for x, y in existing_labels:
            self.add_bee_at(x, y)
# Feed data through an explicit placeholder to avoid using tf.data.
imgs = tf.placeholder(dtype=tf.float32,
                      shape=(1, opts.height, opts.width, 3),
                      name='input_imgs')

# Restore model.
# NOTE: this rebinds the name `model` to the constructed Model object; the
# previous binding (the thing exposing `Model`) is no longer reachable by
# that name after this statement.
model = model.Model(imgs,
                    is_training=False,
                    use_skip_connections=not opts.no_use_skip_connections,
                    base_filter_size=opts.base_filter_size,
                    use_batch_norm=not opts.no_use_batch_norm)
sess = tf.Session()
model.restore(sess, "ckpts/%s" % opts.run)

# Optional label DB for writing results; `db` is None when not requested.
if opts.output_label_db:
    db = LabelDB(label_db_file=opts.output_label_db)
    db.create_if_required()
else:
    db = None

if opts.export_pngs:
    export_dir = "predict_examples/%s" % opts.run
    print("exporting prediction samples to [%s]" % export_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(export_dir, exist_ok=True)

# TODO: make this batched to speed it up for larger runs
# NOTE: `imgs` is rebound here from the input placeholder to the list of
# file names in opts.image_dir.
imgs = os.listdir(opts.image_dir)
if opts.num is not None:
    # Explicit check rather than assert, which is stripped under `python -O`.
    if opts.num <= 0:
        raise ValueError("num must be > 0, got %r" % (opts.num,))
    imgs = random.sample(imgs, opts.num)
class LabelUI(QGraphicsView):
    """PyQt image viewer adapted from
    https://github.com/marcel-goldschen-ohm/PyQtImageViewer/blob/master/QtImageViewer.py

    Steps through the image files found under a directory and lets the user
    place three kinds of label, persisted through LabelDB:

    * bugs            -- left click
    * tickmarks       -- left click while T is held
    * tickmark numbers -- left drag while T is held, then enter the value

    Right click removes the closest label, right drag zooms, ESC resets the
    zoom. RIGHT / LEFT move between images, N jumps to the next image not
    marked complete, C toggles the complete flag, UP toggles label visibility
    and Q saves and quits. Navigation and editing are ignored while labels
    are hidden.
    """

    # Create signals for mouse events. Note that mouse button signals
    # emit (x, y) coordinates but image matrices are indexed (y, x).
    leftMouseButtonPressed = pyqtSignal(float, float)
    rightMouseButtonPressed = pyqtSignal(float, float)
    leftMouseButtonReleased = pyqtSignal(float, float)
    rightMouseButtonReleased = pyqtSignal(float, float)
    leftMouseButtonDoubleClicked = pyqtSignal(float, float)
    rightMouseButtonDoubleClicked = pyqtSignal(float, float)

    # File name suffixes (compared lower-cased) treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif',
                         '.cr2')

    def __init__(self, label_db_filename, img_dir):
        """Collect the image files, open the label DB and show the first image.

        Args:
            label_db_filename: path handed to LabelDB; also the initial
                window title.
            img_dir: directory to walk recursively for images. If None, a
                directory chooser dialog is shown.

        Raises:
            RuntimeError: if the directory does not exist or contains no file
                with a recognised image extension.
        """
        QGraphicsView.__init__(self)
        self.setWindowTitle(label_db_filename)

        if img_dir is None:
            img_dir = str(
                QFileDialog.getExistingDirectory(self,
                                                 'Select image directory'))
        if not os.path.exists(img_dir):
            raise RuntimeError(f'Provided directory {img_dir} does not exist')
        self.img_dir = img_dir

        # Walk through directory tree, get all image files. Each entry is
        # the walked directory path joined with the file name, so it already
        # starts with img_dir.
        files_list = []
        for dir_path, _dir_names, filenames in os.walk(img_dir):
            files_list += [os.path.join(dir_path, f) for f in filenames]
        self.files = sorted(
            f for f in files_list
            if f.lower().endswith(self._IMAGE_EXTENSIONS))
        if not self.files:
            raise RuntimeError(
                f'Unable to find any image files in provided directory {img_dir}'
            )

        # Label db
        self.label_db = LabelDB(label_db_filename)
        self.label_db.create_if_required()

        # A lookup table from bug x,y to any labels that have been added
        self.x_y_to_labels = {}  # { (x, y): Label, ... }

        # Flag to denote if bugs are being displayed or not.
        # While not displayed, we lock down all image navigation
        self.display_labels = True

        # Main review loop
        self.file_idx = 0

        # Image is displayed as a QPixmap in a QGraphicsScene attached to
        # this QGraphicsView
        self.scene = QGraphicsScene()
        self.setScene(self.scene)

        # Store a local handle to the scene's current image pixmap
        self._pixmapHandle = None

        # Scale image to fit inside viewport, preserving aspect ratio
        self.aspectRatioMode = Qt.KeepAspectRatio

        # Shows a scroll bar only when zoomed
        self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)

        # Stack of QRectF zoom boxes in scene coordinates
        self.zoomStack = []

        # Initialize some other variables used occasionally
        self.tmp_x_y = []
        self.click_start_pos = QPoint(0, 0)
        self._t_key_pressed = False
        self._started_tickmark_click = False
        self.complete = False

        self.display_image()
        self.show()
        self.setWindowState(Qt.WindowMaximized)

    def update_title(self):
        """Set the window title to the file name, position and label counts."""
        name = os.path.basename(self.files[self.file_idx])
        num_bugs = 0
        num_tickmarks = 0
        num_tickmark_numbers = 0
        for label in self.x_y_to_labels.values():
            if isinstance(label, Bug):
                num_bugs += 1
            elif isinstance(label, Tickmark):
                num_tickmarks += 1
            elif isinstance(label, TickmarkNumber):
                num_tickmark_numbers += 1

        def counted(count, noun):
            # "1 bug" but "0 bugs" / "2 bugs".
            return f'{count} {noun}' + ('' if count == 1 else 's')

        title = f'{name} ({self.file_idx + 1} of {len(self.files)}): '
        title += ', '.join((counted(num_bugs, 'bug'),
                            counted(num_tickmarks, 'tickmark'),
                            counted(num_tickmark_numbers, 'tickmark number')))
        if self.complete:
            title += ' [COMPLETE]'
        self.setWindowTitle(title)

    def has_image(self):
        """ Returns whether or not the scene contains an image pixmap.
        """
        return self._pixmapHandle is not None

    def clear_image(self):
        """ Removes the current image pixmap from the scene if it exists.
        """
        if self.has_image():
            self.scene.removeItem(self._pixmapHandle)
            self._pixmapHandle = None

    def pixmap(self):
        """ Returns the scene's current image pixmap as a QPixmap, or else
        None if no image exists.
        :rtype: QPixmap | None
        """
        if self.has_image():
            return self._pixmapHandle.pixmap()
        return None

    def set_image(self, image):
        """ Set the scene's current image pixmap to the input QImage or
        QPixmap.
        Raises a RuntimeError if the input image has type other than QImage
        or QPixmap.
        """
        if isinstance(image, QPixmap):
            pixmap = image
        elif isinstance(image, QImage):
            pixmap = QPixmap.fromImage(image)
        else:
            raise RuntimeError(
                'ImageViewer.setImage: Argument must be a QImage or QPixmap.')
        if self.has_image():
            self._pixmapHandle.setPixmap(pixmap)
        else:
            self._pixmapHandle = self.scene.addPixmap(pixmap)
        self.setSceneRect(QRectF(pixmap.rect()))  # Set scene size to image size
        self.update_viewer()

    def update_viewer(self):
        """ Show current zoom (if showing entire image, apply current aspect
        ratio mode).
        """
        if not self.has_image():
            return
        if self.zoomStack and self.sceneRect().contains(self.zoomStack[-1]):
            # Show zoomed rect (ignore aspect ratio)
            self.fitInView(self.zoomStack[-1], Qt.IgnoreAspectRatio)
        else:
            # Clear the zoom stack (in case we got here because of an
            # invalid zoom)
            self.zoomStack = []
            # Show entire image (use current aspect ratio mode)
            self.fitInView(self.sceneRect(), self.aspectRatioMode)

    def resizeEvent(self, event):
        """ Maintain current zoom on resize.
        """
        self.update_viewer()

    def mousePressEvent(self, event):
        """ Start mouse pan or zoom mode.
        """
        scene_pos = self.mapToScene(event.pos())
        self.click_start_pos = event.pos()
        if event.button() == Qt.LeftButton:
            if self._t_key_pressed:
                # T + left button: a click places a tickmark, a drag selects
                # the box for a tickmark number.
                self._started_tickmark_click = True
                self.setDragMode(QGraphicsView.RubberBandDrag)
            else:
                self.setDragMode(QGraphicsView.ScrollHandDrag)
            self.leftMouseButtonPressed.emit(scene_pos.x(), scene_pos.y())
        elif event.button() == Qt.RightButton:
            self.setDragMode(QGraphicsView.RubberBandDrag)
            self.rightMouseButtonPressed.emit(scene_pos.x(), scene_pos.y())
        QGraphicsView.mousePressEvent(self, event)

    def _take_selection_bbox(self):
        """Return the rubber-band selection clipped to the current view.

        Clears the scene's selection area as a side effect. Returns None when
        the clipped selection is invalid or equal to the whole current view.
        """
        view_bbox = self.zoomStack[-1] if self.zoomStack else self.sceneRect()
        selection_bbox = self.scene.selectionArea().boundingRect(
        ).intersected(view_bbox)
        self.scene.setSelectionArea(
            QPainterPath())  # Clear current selection area.
        if selection_bbox.isValid() and (selection_bbox != view_bbox):
            return selection_bbox
        return None

    def mouseReleaseEvent(self, event):
        """ Stop mouse pan or zoom mode (apply zoom if valid).
        """
        QGraphicsView.mouseReleaseEvent(self, event)
        scene_pos = self.mapToScene(event.pos())
        movement_vector = event.pos() - self.click_start_pos
        # Movement of less than one pixel between press and release counts
        # as a click rather than a drag.
        click_distance = math.hypot(movement_vector.x(), movement_vector.y())
        if event.button() == Qt.LeftButton:
            if self._started_tickmark_click:
                if click_distance < 1:
                    self.add_tickmark_event(event)
                else:
                    selection_bbox = self._take_selection_bbox()
                    if selection_bbox is not None:
                        self.add_tickmark_number_event(selection_bbox)
                        # A little hacky, assume someone releases T key to
                        # type the numbers. If they keep it pressed down
                        # they'll be switched back to bug mode and will have
                        # to press it again.
                        self._t_key_pressed = False
            else:
                if click_distance < 1:
                    self.add_bug_event(event)
            self.setDragMode(QGraphicsView.NoDrag)
            self.leftMouseButtonReleased.emit(scene_pos.x(), scene_pos.y())
        elif event.button() == Qt.RightButton:
            selection_bbox = self._take_selection_bbox()
            if selection_bbox is not None:
                self.zoomStack.append(selection_bbox)
                self.update_viewer()
            self.setDragMode(QGraphicsView.NoDrag)
            if click_distance < 1:
                self.remove_closest_label_event(event)
            self.rightMouseButtonReleased.emit(scene_pos.x(), scene_pos.y())
        self._started_tickmark_click = False
        self.update_title()

    def keyPressEvent(self, event):
        """Dispatch the keyboard shortcuts described on the class."""
        key = event.key()
        if key == Qt.Key_Right:
            self.display_next_image()
        elif key == Qt.Key_Left:
            self.display_previous_image()
        elif key == Qt.Key_Up:
            self.toggle_bugs()
        elif key == Qt.Key_N:
            self.display_next_incomplete_image()
        elif key == Qt.Key_Q:
            self.exit_program()
        elif key == Qt.Key_Escape:
            self.zoomStack = []  # Clear zoom stack.
            self.update_viewer()
        elif key == Qt.Key_T:
            self._t_key_pressed = True
        elif key == Qt.Key_C:
            self.complete = not self.complete
            self.update_title()

    def exit_program(self):
        """Save the current image's labels and quit the application."""
        self._flush_pending_x_y_to_boxes()
        QApplication.instance().quit()

    def keyReleaseEvent(self, event):
        """Leave tickmark mode when T is released."""
        if event.key() == Qt.Key_T:
            self._t_key_pressed = False

    def display_image(self):
        """Load the image at file_idx, show it and add its stored labels."""
        # Entries of self.files already contain img_dir (they come from
        # os.walk(img_dir)), so they are used as the path directly. Joining
        # img_dir on again duplicated the prefix for relative directories.
        img_name = self.files[self.file_idx]
        img_path = img_name

        # If this is a raw file it gets special treatment
        if img_path.lower().endswith('.cr2'):
            # Read raw file
            raw = rawpy.imread(img_path)
            # Convert to PIL Image
            img = Image.fromarray(raw.postprocess())
        else:
            img = Image.open(img_path)

        # For some reason RGB images do not like to display in the interface.
        # RGBA seems to work
        img = img.convert('RGBA')

        # Convert to QImage
        img = ImageQt(img)
        self.set_image(img)

        # Look up any existing labels in DB for this image and add them
        self.update_state_from_db(img_name)

    def update_state_from_db(self, img_name):
        """Add all labels and the complete flag stored for img_name."""
        for x, y in self.label_db.get_bugs(img_name):
            self.add_bug_at(x, y)

        for x, y in self.label_db.get_tickmarks(img_name):
            self.add_tickmark_at(x, y)

        for x, y, w, h, val in self.label_db.get_tickmark_numbers(img_name):
            self.add_tickmark_number_at(x, y, w, h, val)

        self.complete = self.label_db.get_complete(img_name)
        self.update_title()

    def display_next_image(self):
        """Save pending labels and show the next image (clamped at the end)."""
        if not self.display_labels:
            print('ignore move to next image; labels not displayed')
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx += 1
        if self.file_idx == len(self.files):
            print("Can't move to image past last image.")
            self.file_idx = len(self.files) - 1
        self.display_image()

    def display_previous_image(self):
        """Save pending labels and show the previous image (clamped at 0)."""
        if not self.display_labels:
            print('ignore move to previous image; labels not displayed')
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx -= 1
        if self.file_idx < 0:
            print("Can't move to image previous to first image.")
            self.file_idx = 0
        self.display_image()

    def display_next_incomplete_image(self):
        """Save pending labels and jump to the next image not marked complete.

        Stops at the last image if every remaining image is complete.
        """
        # Same lock-down as the other navigation handlers; otherwise the new
        # image's labels would be drawn visible while display_labels is False.
        if not self.display_labels:
            print('ignore move to next incomplete image; labels not displayed')
            return
        self._flush_pending_x_y_to_boxes()
        while True:
            self.file_idx += 1
            if self.file_idx == len(self.files):
                print("Can't move to image past last image.")
                self.file_idx = len(self.files) - 1
                break
            if not self.label_db.get_complete(self.files[self.file_idx]):
                break
        self.display_image()

    def add_bug_at(self, x, y):
        """Draw a red square centred on scene position (x, y) as a Bug."""
        # Marker size scales with the scene width (1/300th of it).
        size = self.scene.width() // 300
        rectangle_id = self.scene.addRect(x - size // 2, y - size // 2, size,
                                          size, QPen(Qt.black),
                                          QBrush(Qt.red))
        self.x_y_to_labels[(x, y)] = Bug(x, y, rectangle_id)
        self.update_title()

    def add_bug_event(self, e):
        """Mouse handler: add a bug at the event position."""
        scene_pos = self.mapToScene(e.pos())
        if not self.display_labels:
            print('ignore add bug; labels not displayed')
            return
        self.add_bug_at(scene_pos.x(), scene_pos.y())

    def add_tickmark_at(self, x, y):
        """Draw a blue square centred on scene position (x, y) as a Tickmark."""
        size = self.scene.width() // 300
        rectangle_id = self.scene.addRect(x - size // 2, y - size // 2, size,
                                          size, QPen(Qt.black),
                                          QBrush(Qt.blue))
        self.x_y_to_labels[(x, y)] = Tickmark(x, y, rectangle_id)
        self.update_title()

    def add_tickmark_event(self, e):
        """Mouse handler: add a tickmark at the event position."""
        scene_pos = self.mapToScene(e.pos())
        if not self.display_labels:
            print('ignore add tickmark; labels not displayed')
            return
        self.add_tickmark_at(scene_pos.x(), scene_pos.y())

    def add_tickmark_number_at(self, x, y, width, height, val):
        """Draw a blue box at (x, y) of the given size with val written in it.

        The text uses the largest pixel size whose bounding rectangle is
        still smaller than the box in both dimensions (minimum 1).
        """
        rectangle_id = self.scene.addRect(
            x, y, width, height, QPen(Qt.blue,
                                      self.scene.width() // 300))
        text = str(val)
        font = QFont()
        font_pixel_size = 1
        font.setPixelSize(font_pixel_size)
        # Grow the font until the text no longer fits inside the box.
        while True:
            bounds = QFontMetrics(font).boundingRect(text)
            if bounds.width() >= width or bounds.height() >= height:
                break
            font_pixel_size += 1
            font.setPixelSize(font_pixel_size)
        # Step back to the last size that fit and actually apply it; the
        # decrement used to be computed but never set on the font.
        font_pixel_size = max(1, font_pixel_size - 1)
        font.setPixelSize(font_pixel_size)
        number_canvas_id = self.scene.addText(text, font)
        number_canvas_id.setDefaultTextColor(Qt.blue)
        number_canvas_id.setPos(QPoint(int(x), int(y)))
        self.x_y_to_labels[(x, y)] = TickmarkNumber(x, y, rectangle_id, width,
                                                    height, val,
                                                    number_canvas_id)

    def add_tickmark_number_event(self, box):
        """Ask for the value of a tickmark number and add it for `box`.

        Nothing is added if the user cancels the dialog.
        """
        if not self.display_labels:
            print('ignore add tickmark; labels not displayed')
            return
        val, ok = QInputDialog.getInt(self, 'Input', 'Enter tickmark value:')
        if not ok:
            # Cancelled: do not store a label with the dialog's default value.
            return
        self.add_tickmark_number_at(box.x(), box.y(), box.width(),
                                    box.height(), val)

    def _flush_pending_x_y_to_boxes(self):
        """Write labels to database and remove them from canvas
        """
        img_name = self.files[self.file_idx]

        # Write to database
        self.label_db.set_labels(img_name, self.x_y_to_labels.values())

        # Remove from canvas
        for label in self.x_y_to_labels.values():
            self.remove_label(label)
        self.x_y_to_labels.clear()

        self.label_db.set_complete(img_name, self.complete)

    def toggle_bugs(self):
        """Show or hide every scene item except the image itself."""
        self.display_labels = not self.display_labels
        for item in self.scene.items():
            if not isinstance(item, QGraphicsPixmapItem):
                item.setVisible(self.display_labels)

    def remove_label(self, label):
        """Remove a label's graphics items from the scene."""
        self.scene.removeItem(label.canvas_id)
        if isinstance(label, TickmarkNumber):
            # Tickmark numbers own a second item: the text.
            self.scene.removeItem(label.number_canvas_id)

    def remove_closest_label_event(self, e):
        """Mouse handler: remove the label closest to the event position.

        Distance is measured to each label's (x, y) key in scene coordinates.
        """
        scene_pos = self.mapToScene(e.pos())
        if not self.display_labels:
            print('ignore remove label; labels not displayed')
            return
        if not self.x_y_to_labels:
            return
        px, py = scene_pos.x(), scene_pos.y()
        # Squared distance is enough for ranking; ties resolve to the first
        # label in dict order.
        closest_point = min(
            self.x_y_to_labels,
            key=lambda p: (px - p[0]) ** 2 + (py - p[1]) ** 2)
        self.remove_label(self.x_y_to_labels.pop(closest_point))
        self.update_title()