Example 1
0
File: test.py  Project: yang-neu/bnn
def pr_stats(run, image_dir, label_db, connected_components_threshold,
             num_debug_imgs=4):
    """Compute precision / recall / F1 of a restored model over a directory.

    Every file in ``image_dir`` is visited in sorted filename order. Each is
    loaded and scaled from 0..255 to -1..1, then run through the model
    restored from ``run``. The centroids of the prediction's connected
    components are compared against the labels stored for that filename in
    ``label_db``.

    Args:
      run: run identifier passed to ``m.restore_model``.
      image_dir: directory containing the images to evaluate.
      label_db: label db filename holding the true labels.
      connected_components_threshold: threshold passed to
        ``u.centroids_of_connected_components``.
      num_debug_imgs: how many side-by-side (image, prediction) debug images
        to collect from the first files visited. Defaults to 4.

    Returns:
      dict with keys ``"debug_imgs"``, ``"precision"``, ``"recall"`` and
      ``"f1"``.
    """
    # TODO: a bunch of this can go back into one off init in a class

    _train_opts, model = m.restore_model(run)

    # note: the label_db argument is a filename; it is rebound to the db here
    label_db = LabelDB(label_db_file=label_db)

    # presumably accumulates matches over all images (the per-image counts
    # returned by compare_sets are not needed here); queried once at the end
    set_comparison = u.SetComparison()

    debug_imgs = []

    for filename in sorted(os.listdir(image_dir)):
        # load next image
        # TODO: this block used in various places, refactor
        # the with block releases the file as soon as the pixels are copied
        with Image.open(os.path.join(image_dir, filename)) as pil_img:
            img = np.array(pil_img)  # uint8 0->255
        img = img.astype(np.float32)
        img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

        # run through model: add, then strip, a batch dimension of 1.
        # expit (presumably scipy.special.expit, the logistic sigmoid) maps
        # the raw model output into 0 -> 1
        prediction = expit(model.predict(np.expand_dims(img, 0))[0])

        if len(debug_imgs) < num_debug_imgs:
            debug_imgs.append(u.side_by_side(rgb=img, bitmap=prediction))

        # calc [(x,y), ...] centroids
        predicted_centroids = u.centroids_of_connected_components(
            prediction, rescale=2.0, threshold=connected_components_threshold)

        # compare to true labels. The db points are stored in the opposite
        # coordinate order to the predicted centroids, so swap each pair
        # before comparing.
        true_centroids = [(y, x) for (x, y) in label_db.get_labels(filename)]
        set_comparison.compare_sets(true_centroids, predicted_centroids)

    precision, recall, f1 = set_comparison.precision_recall_f1()

    return {
        "debug_imgs": debug_imgs,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
Example 2
0
import os
import sys
import util as u

# Materialise every entry of a label db as a single channel bitmap image in
# --directory (one .png per .jpg entry).
# TODO: make this multiprocess, too slow as is...

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--label-db',
                    type=str,
                    required=True,
                    help='label_db to materialise bitmaps from')
parser.add_argument('--directory',
                    type=str,
                    required=True,
                    help='directory to store bitmaps')
parser.add_argument('--height',
                    type=int,
                    default=1024,
                    help='height passed to xys_to_bitmap')
parser.add_argument('--width',
                    type=int,
                    default=768,
                    help='width passed to xys_to_bitmap')
parser.add_argument('--rescale',
                    type=float,
                    default=0.5,
                    help='rescale passed to xys_to_bitmap')
opts = parser.parse_args()
print(opts)

label_db = LabelDB(label_db_file=opts.label_db)

# exist_ok avoids the race between checking for the directory and creating it
os.makedirs(opts.directory, exist_ok=True)

fnames = list(label_db.imgs())
for i, fname in enumerate(fnames, start=1):
    bitmap = u.xys_to_bitmap(xys=label_db.get_labels(fname),
                             height=opts.height,
                             width=opts.width,
                             rescale=opts.rescale)
    single_channel_img = u.bitmap_to_single_channel_pil_image(bitmap)
    # swap only a trailing .jpg extension; str.replace would also rewrite a
    # ".jpg" that appears elsewhere in the name
    stem, ext = os.path.splitext(fname)
    out_fname = stem + ".png" if ext == ".jpg" else fname
    # keep the "%s/%s" join: os.path.join would discard the output directory
    # if a db entry were an absolute path
    single_channel_img.save("%s/%s" % (opts.directory, out_fname))
    # "\r" does not trigger a line-buffered flush, so flush explicitly
    sys.stdout.write("%d/%d   \r" % (i, len(fnames)))
    sys.stdout.flush()
# leave the final progress line intact
sys.stdout.write("\n")
Example 3
0
# Compare a db of predicted labels against a db of true labels, printing
# per-image TP / FP / FN counts followed by overall precision, recall and F1.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--true-db', type=str, required=True, help='true labels')
parser.add_argument('--predicted-db', type=str, required=True, help='predicted labels')
opts = parser.parse_args()
# an explicit CLI error rather than `assert`, which is stripped under python -O
if opts.true_db == opts.predicted_db:
    parser.error("--true-db and --predicted-db must be different")

true_db = LabelDB(label_db_file=opts.true_db)
predicted_db = LabelDB(label_db_file=opts.predicted_db)

# iterate over predicted_db; we expect true_db to be a super set of it.
# NOTE(review): these header columns do not match the per-image rows printed
# below ("img", name, TP, FP, FN) -- confirm which layout consumers expect.
print("\t".join(["img", "#1_total", "#2_total", "ad", "#1_left", "#2_left"]))
total_TP = total_FP = total_FN = 0
for img in predicted_db.imgs():
    if not true_db.has_labels(img):
        # note: presumably has_labels() reports whether an entry exists at
        # all, and an existing entry may hold 0 labels
        raise Exception("img %s is in --predicted-db but not --true-db" % img)

    true_labels = true_db.get_labels(img)
    predicted_labels = predicted_db.get_labels(img)
    TP, FP, FN = u.compare_sets(true_labels, predicted_labels)
    print("img", img, TP, FP, FN)

    total_TP += TP
    total_FP += FP
    total_FN += FN

# each ratio falls back to 0.0 when its denominator is 0 (e.g. nothing was
# predicted, or there were no true labels) instead of raising
# ZeroDivisionError
num_predicted = total_TP + total_FP
num_true = total_TP + total_FN
precision = total_TP / num_predicted if num_predicted else 0.0
recall = total_TP / num_true if num_true else 0.0
f1 = (2 * (precision * recall) / (precision + recall)
      if precision + recall else 0.0)
print("precision %0.3f  recall %0.3f  f1 %0.3f" % (precision, recall, f1))
Example 4
0
class LabelUI():
    """Tk window for reviewing a directory of images and labelling points.

    Controls (also printed at start-up):
      RIGHT / LEFT  move to the next / previous image
      UP            toggle display of the current image's points
      N             jump forward to the next image without a label db entry
      left click    add a point; right click removes the closest point

    Points of the current image are written to the label db when navigating
    away from it. All navigation is ignored while points are toggled off.
    """

    def __init__(self, label_db_filename, img_dir, sort=True):
        """Build the UI and block in the Tk main loop.

        Args:
          label_db_filename: label db file; created if required.
          img_dir: directory whose entries are reviewed.
          sort: review in sorted filename order if True, else shuffled.

        Raises:
          ValueError: if img_dir has no entries to review.
        """
        # what images to review?
        # note: drop trailing / in dir name (if present)
        self.img_dir = re.sub("/$", "", img_dir)
        self.files = os.listdir(img_dir)
        if sort:
            self.files = sorted(self.files)
        else:
            random.shuffle(self.files)
        print("%d files to review" % len(self.files))

        # label db
        self.label_db = LabelDB(label_db_filename)
        self.label_db.create_if_required()

        if not self.files:
            # display_new_image below would otherwise die with an IndexError
            raise ValueError("no files to review in %s" % img_dir)

        # TK UI
        root = tk.Tk()
        root.title(label_db_filename)
        root.bind('<Right>', self.display_next_image)
        print("RIGHT  next image")
        root.bind('<Left>', self.display_previous_image)
        print("LEFT   previous image")
        root.bind('<Up>', self.toggle_bees)
        print("UP     toggle labels")
        root.bind('N', self.display_next_unlabelled_image)
        print("N   next image with 0 labels")
        self.canvas = tk.Canvas(root, cursor='tcross')
        self.canvas.config(width=768, height=1024)
        self.canvas.bind('<Button-1>', self.add_bee_event)  # left mouse button
        self.canvas.bind('<Button-3>',
                         self.remove_closest_bee_event)  # right mouse button
        self.canvas.pack()

        # A lookup table from bee x,y to any rectangles that have been drawn
        # in case we want to remove one. the keys of this dict represent all
        # the bee x,y in current image.
        self.x_y_to_boxes = {}  # { (x, y): canvas_id, ... }

        # a flag to denote if bees are being displayed or not
        # while no displayed we lock down all img navigation
        self.bees_on = True
        # points hidden by toggle_bees, restored when toggled back on
        self.tmp_x_y = []

        # Main review loop
        self.file_idx = 0
        self.display_new_image()
        root.mainloop()

    def add_bee_event(self, e):
        """Left click handler: add a point at the click, if points are shown."""
        if not self.bees_on:
            print("ignore add bee; bees not on")
            return
        self.add_bee_at(e.x, e.y)

    def add_bee_at(self, x, y):
        """Draw a small red square centred on (x, y) and record it."""
        rectangle_id = self.canvas.create_rectangle(x - 2,
                                                    y - 2,
                                                    x + 2,
                                                    y + 2,
                                                    fill='red')
        self.x_y_to_boxes[(x, y)] = rectangle_id

    def remove_bee(self, rectangle_id):
        """Delete a point's rectangle from the canvas (mapping untouched)."""
        self.canvas.delete(rectangle_id)

    def toggle_bees(self, e):
        """UP handler: hide all points (stashing them) or restore them."""
        if self.bees_on:
            # store x,y s in tmp list and delete all rectangles from canvas
            self.tmp_x_y = []
            for (x, y), rectangle_id in self.x_y_to_boxes.items():
                self.remove_bee(rectangle_id)
                self.tmp_x_y.append((x, y))
            self.x_y_to_boxes = {}
            self.bees_on = False
        else:  # bees not on
            # restore all temp stored bees
            for x, y in self.tmp_x_y:
                self.add_bee_at(x, y)
            self.bees_on = True

    def remove_closest_bee_event(self, e):
        """Right click handler: remove the point nearest to the click."""
        if not self.bees_on:
            print("ignore remove bee; bees not on")
            return
        if not self.x_y_to_boxes:
            return
        # squared euclidean distance is enough to find the nearest point
        closest_point = min(
            self.x_y_to_boxes,
            key=lambda xy: (e.x - xy[0])**2 + (e.y - xy[1])**2)
        self.remove_bee(self.x_y_to_boxes.pop(closest_point))

    def display_next_image(self, e=None):
        """RIGHT handler: save the current points and show the next image."""
        if not self.bees_on:
            print("ignore move to next image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx += 1
        if self.file_idx == len(self.files):
            print("Can't move to image past last image.")
            self.file_idx = len(self.files) - 1
        self.display_new_image()

    def display_next_unlabelled_image(self, e=None):
        """N handler: save points, then skip ahead to an image with no entry."""
        # same lock as the other navigation handlers: while points are hidden
        # the current image's points live in tmp_x_y and would be lost
        if not self.bees_on:
            print("ignore move to next unlabelled image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        while True:
            self.file_idx += 1
            if self.file_idx == len(self.files):
                print("Can't move to image past last image.")
                self.file_idx = len(self.files) - 1
                break
            if not self.label_db.has_labels(self.files[self.file_idx]):
                break
        self.display_new_image()

    def display_previous_image(self, e=None):
        """LEFT handler: save the current points and show the previous image."""
        if not self.bees_on:
            print("ignore move to previous image; bees not on")
            return
        self._flush_pending_x_y_to_boxes()
        self.file_idx -= 1
        if self.file_idx < 0:
            print("Can't move to image previous to first image.")
            self.file_idx = 0
        self.display_new_image()

    def _flush_pending_x_y_to_boxes(self):
        """Write the current image's points to the label db and forget them.

        Points are written when there are any, or when the image already has
        a db entry, so that removing every point of a labelled image is
        persisted too. Images that were only viewed keep having no entry, so
        N (next unlabelled image) still finds them.
        """
        img_name = self.files[self.file_idx]
        if self.x_y_to_boxes or self.label_db.has_labels(img_name):
            # pass a snapshot: the dict (and any live keys view) is cleared
            # right after
            self.label_db.set_labels(img_name, list(self.x_y_to_boxes))
            self.x_y_to_boxes.clear()

    def display_new_image(self):
        """Draw self.files[self.file_idx], captioned, plus its db points."""
        img_name = self.files[self.file_idx]
        # clear the previous image and its rectangles; flushing only forgets
        # the point -> rectangle mapping, so otherwise canvas items pile up
        # (hidden under each newly drawn image) on every navigation
        self.canvas.delete(tk.ALL)
        # Display image (with filename added)
        img = Image.open(self.img_dir + "/" + img_name)
        canvas = ImageDraw.Draw(img)
        canvas.text((0, 0), img_name, fill='black')
        # keep a reference on self so the PhotoImage is not garbage
        # collected while it is displayed
        self.tk_img = ImageTk.PhotoImage(img)
        self.canvas.create_image(0, 0, image=self.tk_img, anchor=tk.NW)
        # Look up any existing bees in DB for this image.
        for x, y in self.label_db.get_labels(img_name):
            self.add_bee_at(x, y)
Example 5
0
# Copy label entries from --from-db into --into-db. Images that already have
# an entry in --into-db are left untouched and counted as ignored.
import os

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--from-db',
                    type=str,
                    required=True,
                    help='db to take entries from')
parser.add_argument('--into-db',
                    type=str,
                    required=True,
                    help='db to add entries to')
opts = parser.parse_args()

# an explicit CLI error rather than `assert` (stripped under python -O);
# resolved paths also catch e.g. "a.db" vs "./a.db"
if os.path.realpath(opts.from_db) == os.path.realpath(opts.into_db):
    parser.error("--from-db and --into-db must be different files")

from_db = LabelDB(label_db_file=opts.from_db)
into_db = LabelDB(label_db_file=opts.into_db)

num_ignored = 0
num_added = 0
for img in from_db.imgs():
    if into_db.has_labels(img):
        print("ignore", img, "; already in into_db")
        num_ignored += 1
    else:
        into_db.set_labels(img, from_db.get_labels(img))
        num_added += 1

print("num_ignored", num_ignored, "num_added", num_added)
Example 6
0
    # load next image (and add dummy batch dimension)
    img = np.array(Image.open(opts.image_dir + "/" + filename))  # uint8 0->255
    img = img.astype(np.float32)
    img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

    try:
        # run single image through model
        prediction = sess.run(model.output, feed_dict={model.imgs: [img]})[0]

        # calc [(x,y), ...] centroids
        centroids = u.centroids_of_connected_components(prediction,
                                                        rescale=2.0)

        pt_set_distance = 0.0
        if true_db is not None:
            true_centroids = true_db.get_labels(filename)
            print("PREDICTED", centroids)
            print("TRUE", true_centroids)
            pt_set_distance = u.compare_sets(true_pts=true_centroids,
                                             predicted_pts=centroids)

        print("\t".join(
            map(str, ["X", idx, filename, pt_set_distance,
                      len(centroids)])))

        # export some debug image (if requested)
        if opts.export_pngs != '':
            if opts.export_pngs == 'predictions':
                debug_img = u.side_by_side(rgb=img, bitmap=prediction)
            elif opts.export_pngs == 'centroids':
                debug_img = u.red_dots(rgb=img, centroids=centroids)