Exemplo n.º 1
0
    def __init__(self, label_db_filename, img_dir, sort=True):
        """Set up a labelling session and run the Tk event loop.

        The files listed in ``img_dir`` are reviewed in sorted order, or in
        random order when ``sort`` is falsy. They are passed through
        ``u.check_images`` (which supplies the canvas width and height), the
        label database at ``label_db_filename`` is opened, and the window is
        built with its keyboard and mouse bindings. This call blocks inside
        ``root.mainloop()``.
        """
        # Remember the directory name without a trailing "/" (if it had one).
        self.img_dir = re.sub("/$", "", img_dir)

        # Decide the review order: alphabetical or shuffled.
        self.files = os.listdir(img_dir)
        if not sort:
            random.shuffle(self.files)
        else:
            self.files = sorted(self.files)

        # Check that all the image files are okay and share one size.
        image_paths = [os.path.join(img_dir, name) for name in self.files]
        width, height = u.check_images(image_paths)
        print("%d files to review" % len(self.files))

        # Label database backing this session.
        self.label_db = LabelDB(label_db_filename)
        self.label_db.create_if_required()

        # Tk window; each shortcut is bound and then announced on stdout.
        root = tk.Tk()
        root.title(label_db_filename)
        keyboard_shortcuts = (
            ('<Right>', self.display_next_image, "RIGHT  next image"),
            ('<Left>', self.display_previous_image, "LEFT   previous image"),
            ('<Up>', self.toggle_bees, "UP     toggle labels"),
            ('N', self.display_next_unlabelled_image,
             "N   next image with 0 labels"),
            ('Q', self.quit, "Q   quit"),
        )
        for key_sequence, handler, usage in keyboard_shortcuts:
            root.bind(key_sequence, handler)
            print(usage)

        # Drawing surface sized to the images; left click adds a label,
        # right click removes the closest one.
        self.canvas = tk.Canvas(root, cursor='tcross')
        self.canvas.config(width=width, height=height)
        self.canvas.bind('<Button-1>', self.add_bee_event)
        self.canvas.bind('<Button-3>', self.remove_closest_bee_event)
        self.canvas.pack()

        # Lookup from bee (x, y) to the rectangle drawn for it, so a single
        # one can be removed later. The keys are exactly the bee positions
        # in the current image: { (x, y): canvas_id, ... }
        self.x_y_to_boxes = {}

        # Whether bees are currently displayed; while they are hidden all
        # image navigation is locked down.
        self.bees_on = True

        # Start at the first file and hand control to Tk.
        self.file_idx = 0
        self.display_new_image()
        root.mainloop()
Exemplo n.º 2
0
Arquivo: test.py Projeto: yang-neu/bnn
def pr_stats(run, image_dir, label_db, connected_components_threshold):
    """Compute precision / recall / F1 of a restored model over a directory.

    Every file in ``image_dir`` (in sorted filename order) is loaded, scaled
    to the range -1.0 -> 1.0, run through the model, and the centroids of the
    thresholded prediction are compared against the labels stored in the
    label database.

    Args:
        run: identifier handed to ``m.restore_model``.
        image_dir: directory containing the images to evaluate.
        label_db: path of the label database holding the true labels.
        connected_components_threshold: threshold forwarded to
            ``u.centroids_of_connected_components``.

    Returns:
        A dict with keys ``"debug_imgs"`` (side-by-side renderings of the
        first four images), ``"precision"``, ``"recall"`` and ``"f1"``.
    """

    # TODO: a bunch of this can go back into one off init in a class

    _train_opts, model = m.restore_model(run)

    # Use a separate name rather than rebinding the `label_db` path argument.
    true_label_db = LabelDB(label_db_file=label_db)

    # Accumulates the per-image comparisons for the final P/R/F1.
    set_comparison = u.SetComparison()

    # use 4 images for debug
    debug_imgs = []

    for filename in sorted(os.listdir(image_dir)):
        # load next image
        # TODO: this block used in various places, refactor
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the array, instead of leaving it to the GC.
        with Image.open(os.path.join(image_dir, filename)) as pil_img:
            img = np.array(pil_img)  # uint8 0->255
        img = img.astype(np.float32)
        img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

        # run through model (batch of one), squashing outputs with expit
        prediction = expit(model.predict(np.expand_dims(img, 0))[0])

        if len(debug_imgs) < 4:
            debug_imgs.append(u.side_by_side(rgb=img, bitmap=prediction))

        # calc [(x,y), ...] centroids
        predicted_centroids = u.centroids_of_connected_components(
            prediction, rescale=2.0, threshold=connected_components_threshold)

        # compare to true labels; the db stores (x, y) but the predicted
        # centroids use the swapped order, so flip before comparing
        true_centroids = [(y, x)
                          for (x, y) in true_label_db.get_labels(filename)]
        # Called for its accumulation inside set_comparison; the per-image
        # counts it returns are not needed here.
        set_comparison.compare_sets(true_centroids, predicted_centroids)

    precision, recall, f1 = set_comparison.precision_recall_f1()

    return {
        "debug_imgs": debug_imgs,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
Exemplo n.º 3
0
import os
import sys
import util as u

# TODO: make this multiprocess, too slow as is...

# Materialise one single-channel PNG bitmap per image recorded in a label db.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--label-db',
                    type=str,
                    help='label_db to materialise bitmaps from')
parser.add_argument('--directory', type=str, help='directory to store bitmaps')
opts = parser.parse_args()
print(opts)

label_db = LabelDB(label_db_file=opts.label_db)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(opts.directory, exist_ok=True)

# Materialise the names up front so the progress line knows the total.
fnames = list(label_db.imgs())
for i, fname in enumerate(fnames, start=1):
    # Bitmap is rendered for a fixed 768 x 1024 (w x h) image at half scale.
    bitmap = u.xys_to_bitmap(xys=label_db.get_labels(fname),
                             height=1024,
                             width=768,
                             rescale=0.5)
    single_channel_img = u.bitmap_to_single_channel_pil_image(bitmap)
    single_channel_img.save("%s/%s" %
                            (opts.directory, fname.replace(".jpg", ".png")))
    # Counting from 1 lets the in-place progress line finish on "n/n".
    sys.stdout.write("%d/%d   \r" % (i, len(fnames)))
Exemplo n.º 4
0
#!/usr/bin/env python3

import argparse
from label_db import LabelDB
import numpy as np
import util as u

# Compare the labels of a predicted label db against a ground-truth label db.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--true-db', type=str, required=True, help='true labels')
parser.add_argument('--predicted-db', type=str, required=True, help='predicted labels')
opts = parser.parse_args()
# Comparing a db with itself would be meaningless.
assert opts.true_db != opts.predicted_db

true_db = LabelDB(label_db_file=opts.true_db)
predicted_db = LabelDB(label_db_file=opts.predicted_db)

# iterate over predicted_db; we expect true_db to be a super set of it.
print("\t".join(["img", "#1_total", "#2_total", "ad", "#1_left", "#2_left"]))
total_TP = total_FP = total_FN = 0
for img in predicted_db.imgs():
  if not true_db.has_labels(img):
    # note: this can imply 0 labels
    # The message is now formatted with the offending image name; previously
    # the literal "%s" was emitted.
    raise Exception("img %s is in --predicted-db but not --true-db" % img)

  true_labels = true_db.get_labels(img)
  predicted_labels = predicted_db.get_labels(img)
  TP, FP, FN = u.compare_sets(true_labels, predicted_labels)
  print("img", img, TP, FP, FN)

  # NOTE(review): total_FN is initialised above but not accumulated in the
  # visible part of this loop - confirm the rest of the script adds FN.
  total_TP += TP
  total_FP += FP
Exemplo n.º 5
0
# Setup phase of the prediction script: parse options, build and restore the
# model, then prepare the optional outputs (label db and/or exported PNGs).
opts = parser.parse_args()

# feed data through an explicit placeholder to avoid using tf.data
# The leading dimension is fixed to 1 and the trailing one to 3
# (presumably one RGB image of opts.height x opts.width per run - TODO confirm).
imgs = tf.placeholder(dtype=tf.float32, shape=(1, opts.height, opts.width, 3), name='input_imgs')

# restore model
# NOTE(review): this rebinds the name `model` (presumably the imported module
# providing `Model`) to the constructed instance, so `model.Model` is no
# longer reachable through that name after this line.
model = model.Model(imgs,
                    is_training=False,
                    use_skip_connections=not opts.no_use_skip_connections,
                    base_filter_size=opts.base_filter_size,
                    use_batch_norm=not opts.no_use_batch_norm)
sess = tf.Session()
# Restore from "ckpts/<run>".
model.restore(sess, "ckpts/%s" % opts.run)

# Optional label db for output; `db` stays None when none was requested.
if opts.output_label_db:
  db = LabelDB(label_db_file=opts.output_label_db)
  db.create_if_required()
else:
  db = None

# Optional PNG export; the target directory is created on first use.
if opts.export_pngs:
  export_dir = "predict_examples/%s" % opts.run
  print("exporting prediction samples to [%s]" % export_dir)
  if not os.path.exists(export_dir):
    os.makedirs(export_dir)

# TODO: make this batched to speed it up for larger runs

# NOTE: `imgs` is rebound here from the input placeholder (already handed to
# the model above) to the directory listing of opts.image_dir.
imgs = os.listdir(opts.image_dir)
if opts.num is not None:
  # --num, when given, must be a positive count.
  assert opts.num > 0
Exemplo n.º 6
0
# super clumsy :/

# Copy into --into-db every image entry of --from-db that it does not already
# have; entries already present in --into-db are left untouched.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
for flag, description in (('--from-db', 'db to take entries from'),
                          ('--into-db', 'db to add entries to')):
    parser.add_argument(flag, type=str, required=True, help=description)
opts = parser.parse_args()

# Source and destination must be different databases.
assert opts.from_db != opts.into_db

from_db = LabelDB(label_db_file=opts.from_db)
into_db = LabelDB(label_db_file=opts.into_db)

num_ignored = num_added = 0
for img in from_db.imgs():
    if not into_db.has_labels(img):
        # New to the destination: copy its labels across.
        into_db.set_labels(img, from_db.get_labels(img))
        num_added += 1
        continue
    # Destination already has labels for this image; keep them.
    print("ignore", img, "; already in into_db")
    num_ignored += 1

print("num_ignored", num_ignored, "num_added", num_added)
Exemplo n.º 7
0
    def __init__(self, label_db_filename, img_dir):
        """Create the labelling view, collect the images and show the window.

        When ``img_dir`` is None the user is asked to choose a directory.
        The directory tree is walked recursively and every file with a known
        image extension is queued for review, in sorted path order.

        Raises:
            RuntimeError: if the directory does not exist or holds no files
                with a recognised image extension.
        """
        QGraphicsView.__init__(self)
        self.setWindowTitle(label_db_filename)

        # No directory supplied: let the user pick one interactively.
        if img_dir is None:
            chosen_dir = QFileDialog.getExistingDirectory(
                self, 'Select image directory')
            img_dir = str(chosen_dir)

        if not os.path.exists(img_dir):
            raise RuntimeError(f'Provided directory {img_dir} does not exist')

        self.img_dir = img_dir

        # Gather every file below img_dir, then keep only those whose path
        # ends (case-insensitively) with a known image extension.
        image_suffixes = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif',
                          '.cr2')
        candidates = []
        for folder, _subdirs, names in os.walk(img_dir):
            candidates.extend(os.path.join(folder, name) for name in names)
        self.files = sorted(
            path for path in candidates
            if path.lower().endswith(image_suffixes))

        if not self.files:
            raise RuntimeError(
                f'Unable to find any image files in provided directory {img_dir}'
            )

        # Label database backing this session.
        self.label_db = LabelDB(label_db_filename)
        self.label_db.create_if_required()

        # Lookup from bug (x, y) to the label added there:
        # { (x, y): Label, ... }
        self.x_y_to_labels = {}

        # Whether bugs are currently displayed; while they are hidden all
        # image navigation is locked down.
        self.display_labels = True

        # Index of the file currently under review.
        self.file_idx = 0

        # The image is shown as a QPixmap inside a QGraphicsScene attached
        # to this QGraphicsView.
        self.scene = QGraphicsScene()
        self.setScene(self.scene)

        # Local handle to the scene's current image pixmap (none yet).
        self._pixmapHandle = None

        # Fit the image inside the viewport while keeping its aspect ratio.
        self.aspectRatioMode = Qt.KeepAspectRatio

        # Scroll bars appear only when needed, i.e. when zoomed.
        self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)

        # Stack of QRectF zoom boxes, in scene coordinates.
        self.zoomStack = []

        # Miscellaneous interaction state used occasionally.
        self.tmp_x_y = []
        self.click_start_pos = QPoint(0, 0)
        self._t_key_pressed = False
        self._started_tickmark_click = False
        self.complete = False

        # Show the first image and bring the window up maximized.
        self.display_image()
        self.show()
        self.setWindowState(Qt.WindowMaximized)