def pr_stats(run, image_dir, label_db, connected_components_threshold): # TODO: a bunch of this can go back into one off init in a class _train_opts, model = m.restore_model(run) label_db = LabelDB(label_db_file=label_db) set_comparison = u.SetComparison() # use 4 images for debug debug_imgs = [] for idx, filename in enumerate(sorted(os.listdir(image_dir))): # load next image # TODO: this block used in various places, refactor img = np.array(Image.open(image_dir + "/" + filename)) # uint8 0->255 (H, W) img = img.astype(np.float32) img = (img / 127.5) - 1.0 # -1.0 -> 1.0 # see data.py # run through model prediction = expit(model.predict(np.expand_dims(img, 0))[0]) if len(debug_imgs) < 4: debug_imgs.append(u.side_by_side(rgb=img, bitmap=prediction)) # calc [(x,y), ...] centroids predicted_centroids = u.centroids_of_connected_components( prediction, rescale=2.0, threshold=connected_components_threshold) # compare to true labels true_centroids = label_db.get_bugs(filename) true_centroids = [(y, x) for (x, y) in true_centroids] # sigh... tp, fn, fp = set_comparison.compare_sets(true_centroids, predicted_centroids) precision, recall, f1 = set_comparison.precision_recall_f1() return { "debug_imgs": debug_imgs, "precision": precision, "recall": recall, "f1": f1 }
class LabelUI(QGraphicsView): """PyQt image viewer adapted from https://github.com/marcel-goldschen-ohm/PyQtImageViewer/blob/master/QtImageViewer.py """ # Create signals for mouse events. Note that mouse button signals # emit (x, y) coordinates but image matrices are indexed (y, x). leftMouseButtonPressed = pyqtSignal(float, float) rightMouseButtonPressed = pyqtSignal(float, float) leftMouseButtonReleased = pyqtSignal(float, float) rightMouseButtonReleased = pyqtSignal(float, float) leftMouseButtonDoubleClicked = pyqtSignal(float, float) rightMouseButtonDoubleClicked = pyqtSignal(float, float) def __init__(self, label_db_filename, img_dir): QGraphicsView.__init__(self) self.setWindowTitle(label_db_filename) if img_dir is None: img_dir = str( QFileDialog.getExistingDirectory(self, 'Select image directory')) if not os.path.exists(img_dir): raise RuntimeError(f'Provided directory {img_dir} does not exist') self.img_dir = img_dir files_list = [] # Walk through directory tree, get all files for dir_path, dir_names, filenames in os.walk(img_dir): files_list += [os.path.join(dir_path, f) for f in filenames] files_list = sorted(files_list) files_list = list( filter( lambda x: x.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif', '.cr2')), files_list)) self.files = files_list if len(self.files) == 0: raise RuntimeError( f'Unable to find any image files in provided directory {img_dir}' ) # Label db self.label_db = LabelDB(label_db_filename) self.label_db.create_if_required() # A lookup table from bug x,y to any labels that have been added self.x_y_to_labels = {} # { (x, y): Label, ... } # Flag to denote if bugs are being displayed or not. # While not displayed, we lock down all image navigation self.display_labels = True # Main review loop self.file_idx = 0 # Image is displayed as a QPixmap in a QGraphicsScene attached to this QGraphicsView self.scene = QGraphicsScene() self.setScene(self.scene) # Store a local handle to the scene's current image pixmap self._pixmapHandle = None # Scale image to fit inside viewport, preserving aspect ratio self.aspectRatioMode = Qt.KeepAspectRatio # Shows a scroll bar only when zoomed self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) # Stack of QRectF zoom boxes in scene coordinates self.zoomStack = [] # Initialize some other variables used occasionally self.tmp_x_y = [] self.click_start_pos = QPoint(0, 0) self._t_key_pressed = False self._started_tickmark_click = False self.complete = False self.display_image() self.show() self.setWindowState(Qt.WindowMaximized) def update_title(self): name = os.path.basename(self.files[self.file_idx]) num_bugs = 0 num_tickmarks = 0 num_tickmark_numbers = 0 for label in self.x_y_to_labels.values(): if isinstance(label, Bug): num_bugs += 1 elif isinstance(label, Tickmark): num_tickmarks += 1 elif isinstance(label, TickmarkNumber): num_tickmark_numbers += 1 title = f'{name} ({self.file_idx + 1} of {len(self.files)}): ' title += f'{num_bugs} bug' if num_bugs != 1: title += 's' title += f', {num_tickmarks} tickmark' if num_tickmarks != 1: title += 's' title += f', {num_tickmark_numbers} tickmark number' if num_tickmark_numbers != 1: title += 's' if self.complete: title += ' [COMPLETE]' self.setWindowTitle(title) def has_image(self): """ Returns whether or not the scene contains an image pixmap. """ return self._pixmapHandle is not None def clear_image(self): """ Removes the current image pixmap from the scene if it exists. """ if self.has_image(): self.scene.removeItem(self._pixmapHandle) self._pixmapHandle = None def pixmap(self): """ Returns the scene's current image pixmap as a QPixmap, or else None if no image exists. :rtype: QPixmap | None """ if self.has_image(): return self._pixmapHandle.pixmap() return None def set_image(self, image): """ Set the scene's current image pixmap to the input QImage or QPixmap. Raises a RuntimeError if the input image has type other than QImage or QPixmap. """ if type(image) is QPixmap: pixmap = image elif isinstance(image, QImage): pixmap = QPixmap.fromImage(image) else: raise RuntimeError( 'ImageViewer.setImage: Argument must be a QImage or QPixmap.') if self.has_image(): self._pixmapHandle.setPixmap(pixmap) else: self._pixmapHandle = self.scene.addPixmap(pixmap) self.setSceneRect(QRectF( pixmap.rect())) # Set scene size to image size self.update_viewer() def update_viewer(self): """ Show current zoom (if showing entire image, apply current aspect ratio mode). """ if not self.has_image(): return if len(self.zoomStack) and self.sceneRect().contains( self.zoomStack[-1]): self.fitInView( self.zoomStack[-1], Qt.IgnoreAspectRatio) # Show zoomed rect (ignore aspect ratio) else: # Clear the zoom stack (in case we got here because of an invalid zoom) self.zoomStack = [] # Show entire image (use current aspect ratio mode) self.fitInView(self.sceneRect(), self.aspectRatioMode) def resizeEvent(self, event): """ Maintain current zoom on resize. """ self.update_viewer() def mousePressEvent(self, event): """ Start mouse pan or zoom mode. """ scene_pos = self.mapToScene(event.pos()) self.click_start_pos = event.pos() if event.button() == Qt.LeftButton: if self._t_key_pressed: self._started_tickmark_click = True self.setDragMode(QGraphicsView.RubberBandDrag) else: self.setDragMode(QGraphicsView.ScrollHandDrag) self.leftMouseButtonPressed.emit(scene_pos.x(), scene_pos.y()) elif event.button() == Qt.RightButton: self.setDragMode(QGraphicsView.RubberBandDrag) self.rightMouseButtonPressed.emit(scene_pos.x(), scene_pos.y()) QGraphicsView.mousePressEvent(self, event) def mouseReleaseEvent(self, event): """ Stop mouse pan or zoom mode (apply zoom if valid). """ QGraphicsView.mouseReleaseEvent(self, event) scene_pos = self.mapToScene(event.pos()) movement_vector = event.pos() - self.click_start_pos click_distance = math.sqrt(movement_vector.x()**2 + movement_vector.y()**2) if event.button() == Qt.LeftButton: if self._started_tickmark_click: if click_distance < 1: self.add_tickmark_event(event) else: view_bbox = self.zoomStack[-1] if len( self.zoomStack) else self.sceneRect() selection_bbox = self.scene.selectionArea().boundingRect( ).intersected(view_bbox) self.scene.setSelectionArea( QPainterPath()) # Clear current selection area. if selection_bbox.isValid() and (selection_bbox != view_bbox): self.add_tickmark_number_event(selection_bbox) # A little hacky, assume someone releases T key to type the numbers. # If they keep it pressed down they'll be switched back to bug mode # and will have to press it again. self._t_key_pressed = False else: if click_distance < 1: self.add_bug_event(event) self.setDragMode(QGraphicsView.NoDrag) self.leftMouseButtonReleased.emit(scene_pos.x(), scene_pos.y()) elif event.button() == Qt.RightButton: view_bbox = self.zoomStack[-1] if len( self.zoomStack) else self.sceneRect() selection_bbox = self.scene.selectionArea().boundingRect( ).intersected(view_bbox) self.scene.setSelectionArea( QPainterPath()) # Clear current selection area. if selection_bbox.isValid() and (selection_bbox != view_bbox): self.zoomStack.append(selection_bbox) self.update_viewer() self.setDragMode(QGraphicsView.NoDrag) if click_distance < 1: self.remove_closest_label_event(event) self.rightMouseButtonReleased.emit(scene_pos.x(), scene_pos.y()) self._started_tickmark_click = False self.update_title() def keyPressEvent(self, event): if event.key() == Qt.Key_Right: self.display_next_image() elif event.key() == Qt.Key_Left: self.display_previous_image() elif event.key() == Qt.Key_Up: self.toggle_bugs() elif event.key() == Qt.Key_N: self.display_next_incomplete_image() elif event.key() == Qt.Key_Q: self.exit_program() elif event.key() == Qt.Key_Escape: self.zoomStack = [] # Clear zoom stack. self.update_viewer() elif event.key() == Qt.Key_T: self._t_key_pressed = True elif event.key() == Qt.Key_C: self.complete = False if self.complete else True self.update_title() def exit_program(self): self._flush_pending_x_y_to_boxes() QApplication.instance().quit() def keyReleaseEvent(self, event): if event.key() == Qt.Key_T: self._t_key_pressed = False def display_image(self): # Open image img_name = self.files[self.file_idx] img_path = os.path.join(self.img_dir, img_name) # If this is a raw file it gets special treatment if img_path.lower().endswith('.cr2'): # Read raw file raw = rawpy.imread(img_path) # Convert to PIL Image img = Image.fromarray(raw.postprocess()) else: img = Image.open(img_path) # For some reason RGB images do not like to display in the interface. # RGBA seems to work img = img.convert('RGBA') # Convert to QImage img = ImageQt(img) self.set_image(img) # Look up any existing labels in DB for this image and add them self.update_state_from_db(img_name) def update_state_from_db(self, img_name): existing_bugs = self.label_db.get_bugs(img_name) for x, y in existing_bugs: self.add_bug_at(x, y) existing_tickmarks = self.label_db.get_tickmarks(img_name) for x, y in existing_tickmarks: self.add_tickmark_at(x, y) existing_tickmark_numbers = self.label_db.get_tickmark_numbers( img_name) for x, y, w, h, val in existing_tickmark_numbers: self.add_tickmark_number_at(x, y, w, h, val) complete = self.label_db.get_complete(img_name) self.complete = complete self.update_title() def display_next_image(self): if not self.display_labels: print('ignore move to next image; labels not displayed') return self._flush_pending_x_y_to_boxes() self.file_idx += 1 if self.file_idx == len(self.files): print("Can't move to image past last image.") self.file_idx = len(self.files) - 1 self.display_image() def display_previous_image(self): if not self.display_labels: print('ignore move to previous image; labels not displayed') return self._flush_pending_x_y_to_boxes() self.file_idx -= 1 if self.file_idx < 0: print("Can't move to image previous to first image.") self.file_idx = 0 self.display_image() def display_next_incomplete_image(self): self._flush_pending_x_y_to_boxes() while True: self.file_idx += 1 if self.file_idx == len(self.files): print("Can't move to image past last image.") self.file_idx = len(self.files) - 1 break if not self.label_db.get_complete(self.files[self.file_idx]): break self.display_image() def add_bug_at(self, x, y): size = self.scene.width() // 300 rectangle_id = self.scene.addRect(x - size // 2, y - size // 2, size, size, QPen(Qt.black), QBrush(Qt.red)) self.x_y_to_labels[(x, y)] = Bug(x, y, rectangle_id) self.update_title() def add_bug_event(self, e): scene_pos = self.mapToScene(e.pos()) if not self.display_labels: print('ignore add bug; labels not displayed') return self.add_bug_at(scene_pos.x(), scene_pos.y()) def add_tickmark_at(self, x, y): size = self.scene.width() // 300 rectangle_id = self.scene.addRect(x - size // 2, y - size // 2, size, size, QPen(Qt.black), QBrush(Qt.blue)) self.x_y_to_labels[(x, y)] = Tickmark(x, y, rectangle_id) self.update_title() def add_tickmark_event(self, e): scene_pos = self.mapToScene(e.pos()) if not self.display_labels: print('ignore add tickmark; labels not displayed') return self.add_tickmark_at(scene_pos.x(), scene_pos.y()) def add_tickmark_number_at(self, x, y, width, height, val): rectangle_id = self.scene.addRect( x, y, width, height, QPen(Qt.blue, self.scene.width() // 300)) font = QFont() font_pixel_size = 1 font.setPixelSize(font_pixel_size) while QFontMetrics(font).boundingRect(str(val)).width() < width \ and QFontMetrics(font).boundingRect(str(val)).height() < height: font_pixel_size += 1 font.setPixelSize(font_pixel_size) font_pixel_size -= 1 number_canvas_id = self.scene.addText(str(val), font) number_canvas_id.setDefaultTextColor(Qt.blue) number_canvas_id.setPos(QPoint(int(x), int(y))) self.x_y_to_labels[(x, y)] = TickmarkNumber(x, y, rectangle_id, width, height, val, number_canvas_id) def add_tickmark_number_event(self, box): if not self.display_labels: print('ignore add tickmark; labels not displayed') return val, _ = QInputDialog.getInt(self, 'Input', 'Enter tickmark value:') self.add_tickmark_number_at(box.x(), box.y(), box.width(), box.height(), val) def _flush_pending_x_y_to_boxes(self): """Write labels to database and remove them from canvas """ img_name = self.files[self.file_idx] # Write to database self.label_db.set_labels(img_name, self.x_y_to_labels.values()) # Remove from canvas for label in self.x_y_to_labels.values(): self.remove_label(label) self.x_y_to_labels.clear() self.label_db.set_complete(img_name, self.complete) def toggle_bugs(self): # if self.display_labels: # # store x,y s in tmp list and delete all rectangles from canvas # self.tmp_x_y = [] # for (x, y), rectangle_id in self.x_y_to_labels.items(): # self.remove_label(rectangle_id) # self.tmp_x_y.append((x, y)) # self.x_y_to_labels = {} # self.display_labels = False # else: # labels not displayed # # restore all temp stored bugs # for x, y in self.tmp_x_y: # self.add_bug_at(x, y) # self.display_labels = True self.display_labels = not self.display_labels for item in self.scene.items(): if not isinstance(item, QGraphicsPixmapItem): item.setVisible(self.display_labels) def remove_label(self, label): canvas_id = label.canvas_id self.scene.removeItem(canvas_id) if isinstance(label, TickmarkNumber): number_canvas_id = label.number_canvas_id self.scene.removeItem(number_canvas_id) def remove_closest_label_event(self, e): scene_pos = self.mapToScene(e.pos()) if not self.display_labels: print('ignore remove label; labels not displayed') return if len(self.x_y_to_labels) == 0: return closest_point = None closest_sqr_distance = 0.0 for x, y in self.x_y_to_labels.keys(): sqr_distance = (scene_pos.x() - x)**2 + (scene_pos.y() - y)**2 if sqr_distance < closest_sqr_distance or closest_point is None: closest_point = (x, y) closest_sqr_distance = sqr_distance self.remove_label(self.x_y_to_labels.pop(closest_point)) self.update_title()
for filename in filenames: original_filename = filename filename = bnn_util.get_path_relative_to_drive(filename) drive_base_path = os.path.expanduser( '~/data/srpa226-drive/Sharing/202012 Paul/') filename = os.path.join(drive_base_path, filename) if not os.path.exists(filename): print(f'File not found, skipping: {filename}') else: img = Image.open(filename) width, height = img.size if not label_db.get_complete(original_filename): print(f'Image labeling not complete, skipping: {filename}') else: bitmap = bnn_util.xys_to_bitmap( xys=label_db.get_bugs(original_filename), height=height, width=width, rescale=opts.label_rescale) single_channel_img = bnn_util.bitmap_to_single_channel_pil_image( bitmap) print(f'Processing {filename}') img_new_filename = os.path.join(opts.image_output_dir, os.path.basename(filename)) bitmap_filename = os.path.basename( os.path.splitext(filename)[0] + '_train_bitmap_bugs.png') bitmap_filename = os.path.join(opts.label_output_dir, bitmap_filename) img.save(img_new_filename) single_channel_img.save(bitmap_filename)