Esempio n. 1
0
    def test(self, run):
        """Evaluate the checkpoint for `run` over one pass of the test iterator.

        Restores the model from "ckpts/<run>", runs the iterator init op and
        then fetches batches until the session raises
        tf.errors.OutOfRangeError. For every image of every batch, the
        centroids of the connected components of the true bitmap and of the
        predicted bitmap are handed to a u.SetComparison, from which the
        final precision / recall / f1 are read.

        Args:
            run: run name; only used to build the checkpoint path.

        Returns:
            dict with keys:
              "num_imgs": total number of images seen across all batches.
              "debug_img": result of u.debug_img for one randomly chosen
                element of the first batch (for tensorboard); None if the
                iterator produced no batch at all.
              "precision", "recall", "f1": from
                SetComparison.precision_recall_f1().
              "xent": np.mean of the per batch xent losses.
        """
        # fetched for every batch; test_imgs is only added while debug_img
        # still has to be created.
        bitmap_fetches = [
            self.test_xys_bitmaps, self.model.output, self.model.xent_loss
        ]

        with tf.Session() as sess:
            self.model.restore(sess, "ckpts/%s" % run)
            sess.run(self.iter_init_op)

            set_comparison = u.SetComparison()
            num_imgs = 0
            xent_losses = []
            debug_img = None  # created from the first batch

            while True:
                need_debug_img = debug_img is None
                if need_debug_img:
                    # fetch imgs as well to create debug_img
                    fetches = [self.test_imgs] + bitmap_fetches
                else:
                    fetches = bitmap_fetches

                # only the fetch itself signals the end of the iterator, so
                # the try is limited to it; errors raised by the post
                # processing below are never mistaken for "end of data".
                try:
                    fetched = sess.run(fetches)
                except tf.errors.OutOfRangeError:
                    # end of iterator
                    break

                if need_debug_img:
                    imgs, true_bitmaps, predicted_bitmaps, xent_loss = fetched
                    # choose a random element from batch
                    debug_idx = random.randint(0, true_bitmaps.shape[0] - 1)
                    debug_img = u.debug_img(imgs[debug_idx],
                                            true_bitmaps[debug_idx],
                                            predicted_bitmaps[debug_idx])
                else:
                    true_bitmaps, predicted_bitmaps, xent_loss = fetched

                xent_losses.append(xent_loss)
                num_imgs += true_bitmaps.shape[0]

                for true_bitmap, predicted_bitmap in zip(true_bitmaps,
                                                         predicted_bitmaps):
                    # this is dumb; should do against label db!
                    true_centroids = u.centroids_of_connected_components(
                        true_bitmap)
                    predicted_centroids = u.centroids_of_connected_components(
                        predicted_bitmap)
                    # the per call return value is not needed here; the
                    # totals are read from set_comparison after the loop.
                    set_comparison.compare_sets(true_centroids,
                                                predicted_centroids)

        precision, recall, f1 = set_comparison.precision_recall_f1()

        return {
            "num_imgs": num_imgs,
            "debug_img": debug_img,  # for tensorboard
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "xent": np.mean(xent_losses)
        }
Esempio n. 2
0
File: test.py Progetto: yang-neu/bnn
def pr_stats(run, image_dir, label_db, connected_components_threshold,
             num_debug_imgs=4):
    """Compute precision / recall / f1 of a restored model over an image dir.

    Every file in `image_dir` (in sorted filename order) is loaded, rescaled
    to -1.0 -> 1.0, run through the model restored for `run`, and the
    centroids of the connected components of the prediction are compared
    against the labels stored for that filename in the label db.

    Args:
        run: run identifier passed to m.restore_model.
        image_dir: directory whose entries are all opened as images.
        label_db: label db file, passed to LabelDB(label_db_file=...).
        connected_components_threshold: threshold forwarded to
            u.centroids_of_connected_components.
        num_debug_imgs: how many side by side debug images to collect, taken
            from the first images processed. Defaults to 4.

    Returns:
        dict with keys "debug_imgs" (list of u.side_by_side results),
        "precision", "recall" and "f1".
    """

    # TODO: a bunch of this can go back into one off init in a class

    _train_opts, model = m.restore_model(run)

    label_db = LabelDB(label_db_file=label_db)

    set_comparison = u.SetComparison()

    debug_imgs = []

    for filename in sorted(os.listdir(image_dir)):
        # load next image
        # TODO: this block used in various places, refactor
        # the with block closes the underlying file as soon as the pixel
        # data has been copied into the array, instead of leaking one open
        # handle per image until garbage collection.
        with Image.open(os.path.join(image_dir, filename)) as pil_img:
            img = np.array(pil_img)  # uint8 0->255
        img = img.astype(np.float32)
        img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

        # run through model (adding / removing dummy batch dimension) and
        # squash the output with expit
        prediction = expit(model.predict(np.expand_dims(img, 0))[0])

        if len(debug_imgs) < num_debug_imgs:
            debug_imgs.append(u.side_by_side(rgb=img, bitmap=prediction))

        # calc [(x,y), ...] centroids
        predicted_centroids = u.centroids_of_connected_components(
            prediction, rescale=2.0, threshold=connected_components_threshold)

        # compare to true labels
        true_centroids = label_db.get_labels(filename)
        # swap the element order of each label pair; presumably to match the
        # ordering used by the predicted centroids -- TODO confirm
        true_centroids = [(y, x) for (x, y) in true_centroids]  # sigh...
        # per image counts are not needed; totals are read after the loop
        set_comparison.compare_sets(true_centroids, predicted_centroids)

    precision, recall, f1 = set_comparison.precision_recall_f1()

    return {
        "debug_imgs": debug_imgs,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
Esempio n. 3
0
def dump_images(prefix):
    """Write three 8x8 collage PNGs for one batch under images/ra/.

    Runs `imgs` and `model.output` once through `sess` (all three names come
    from the enclosing scope) and tiles the first 64 results into:
      * <prefix>_imgs.png      -- the input images, 17 pixel grid stride
      * <prefix>_bitmaps.png   -- the model output bitmaps, 9 pixel stride
      * <prefix>_centroids.png -- bitmaps rebuilt from the centroids of the
                                  connected components of each output

    Args:
        prefix: filename prefix for the three saved images.
    """
    grid = 8
    img_stride = 17
    bitmap_stride = 9
    white = (255, 255, 255)
    bitmap_collage_size = (bitmap_stride * grid, bitmap_stride * grid)

    # run from imgs -> bitmap and stitch them together...
    img_collage = Image.new('RGB', (img_stride * grid, img_stride * grid),
                            (0, 0, 0))
    bitmap_collage = Image.new('RGB', bitmap_collage_size, white)
    centroids_collage = Image.new('RGB', bitmap_collage_size, white)

    batch_imgs, batch_bitmaps = sess.run([imgs, model.output])

    for i in range(grid * grid):
        # i == (x * 8) + y, so x is the slow moving index
        x, y = divmod(i, grid)
        bitmap_offset = (bitmap_stride * x, bitmap_stride * y)

        tile = u.zero_centered_array_to_pil_image(batch_imgs[i])
        img_collage.paste(tile, (img_stride * x, img_stride * y))

        bitmap_collage.paste(u.bitmap_to_pil_image(batch_bitmaps[i]),
                             bitmap_offset)

        centroids = u.centroids_of_connected_components(batch_bitmaps[i])
        centroid_tile = u.bitmap_to_single_channel_pil_image(
            u.bitmap_from_centroids(centroids, h=8, w=8))
        centroids_collage.paste(centroid_tile, bitmap_offset)

    img_collage.save("images/ra/%s_imgs.png" % prefix)
    bitmap_collage.save("images/ra/%s_bitmaps.png" % prefix)
    centroids_collage.save("images/ra/%s_centroids.png" % prefix)
Esempio n. 4
0
  imgs = random.sample(imgs, opts.num)

for idx, filename in enumerate(sorted(imgs)):

  # load next image (and add dummy batch dimension)
  img = np.array(Image.open(opts.image_dir+"/"+filename))  # uint8 0->255
  img = img.astype(np.float32)
  img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

  try:
    # run single image through model
    prediction = sess.run(model.output, feed_dict={model.imgs: [img]})[0]

    # calc [(x,y), ...] centroids
    centroids = u.centroids_of_connected_components(prediction,
                                                    rescale=2.0,
                                                    threshold=opts.connected_components_threshold)

    print("\t".join(map(str, [idx, filename, len(centroids)])))

    # export some debug image (if requested)
    if opts.export_pngs != '':
      if opts.export_pngs == 'predictions':
        debug_img = u.side_by_side(rgb=img, bitmap=prediction)
      elif opts.export_pngs == 'centroids':
        debug_img = u.red_dots(rgb=img, centroids=centroids)
      else:
        raise Exception("unknown --export-pngs option")
      debug_img.save("%s/%s.png" % (export_dir, filename))

    # set new labels (if requested)
Esempio n. 5
0
    ])
    train_summaries_writer.add_summary(u.explicit_summaries({"xent": xl}),
                                       step)
    debug_img_summary = u.pil_image_to_tf_summary(u.debug_img(i, bm, o))
    train_summaries_writer.add_summary(debug_img_summary, step)
    train_summaries_writer.flush()

    # ... test
    i, bm, o, xl, step = sess.run([
        test_imgs, test_xys_bitmaps, test_model.output, test_model.xent_loss,
        global_step
    ])

    set_comparison = u.SetComparison()
    for idx in range(bm.shape[0]):
        true_centroids = u.centroids_of_connected_components(bm[idx])
        predicted_centroids = u.centroids_of_connected_components(o[idx])
        set_comparison.compare_sets(true_centroids, predicted_centroids)
    precision, recall, f1 = set_comparison.precision_recall_f1()
    tag_values = {
        "xent": xl,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    test_summaries_writer.add_summary(u.explicit_summaries(tag_values), step)

    debug_img_summary = u.pil_image_to_tf_summary(u.debug_img(i, bm, o))
    test_summaries_writer.add_summary(debug_img_summary, step)
    test_summaries_writer.flush()
Esempio n. 6
0
    # load next image (and add dummy batch dimension)
    img = np.array(Image.open(opts.image_dir + "/" + filename))  # uint8 0->255
    img = img.astype(np.float32)
    img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

    try:
        # run single image through model
        s = time.time()
        prediction = sess.run(model_output,
                              feed_dict={imgs_placeholder: [img]})[0]
        prediction_time = time.time() - s
        prediction_times.append(prediction_time)

        # calc [(x,y), ...] centroids
        s = time.time()
        centroids = u.centroids_of_connected_components(prediction,
                                                        rescale=2.0)
        centroid_calc_time = time.time() - s
        centroid_calc_times.append(centroid_calc_time)

        print("\t".join(
            map(str, [
                idx, filename,
                len(centroids), prediction_time, centroid_calc_time
            ])))

        # export some debug image (if requested)
        if opts.export_pngs != '':
            if opts.export_pngs == 'predictions':
                debug_img = u.side_by_side(rgb=img, bitmap=prediction)
            elif opts.export_pngs == 'centroids':
                debug_img = u.red_dots(rgb=img, centroids=centroids)
Esempio n. 7
0
for idx, filename in enumerate(sorted(imgs)):

    # load next image
    img = np.array(Image.open(opts.image_dir + "/" +
                              filename))  # uint8 0->255  (H, W)
    img = img.astype(np.float32)
    img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

    # run through model (adding / removing dummy batch)
    # recall: output from model is logits so we need to expit
    # TODO: do this in batch !!
    prediction = expit(model.predict(np.expand_dims(img, 0))[0])

    # calc [(x,y), ...] centroids
    centroids = u.centroids_of_connected_components(
        prediction,
        rescale=2.0,
        threshold=train_opts['connected_components_threshold'])
    print("\t".join(map(str, [idx, filename, len(centroids)])))

    # export some debug image (if requested)
    if opts.export_pngs != '':
        if opts.export_pngs == 'predictions':
            debug_img = u.side_by_side(rgb=img, bitmap=prediction)
        elif opts.export_pngs == 'centroids':
            debug_img = u.red_dots(rgb=img, centroids=centroids)
        else:
            raise Exception("unknown --export-pngs option")
        debug_img.save("%s/%s.png" % (export_dir, filename))

    # set new labels (if requested)
    if db: