示例#1
0
文件: test.py 项目: yang-neu/bnn
def pr_stats(run, image_dir, label_db, connected_components_threshold):
    """Compute precision / recall / F1 of predicted centroids over a directory.

    Every file in ``image_dir`` is visited in sorted filename order. Each
    image is rescaled to the range [-1.0, 1.0], run through the model
    restored for ``run``, and the model output is passed through ``expit``.
    Centroids of the connected components of that prediction are then
    compared with the labels stored for the same filename.

    Args:
        run: identifier handed to ``m.restore_model`` to load the model.
        image_dir: directory containing the images to evaluate.
        label_db: label database file, opened through ``LabelDB``.
        connected_components_threshold: threshold forwarded to
            ``u.centroids_of_connected_components``.

    Returns:
        A dict with keys ``"debug_imgs"`` (``u.side_by_side`` renderings of
        at most the first four images), ``"precision"``, ``"recall"`` and
        ``"f1"`` (as reported by ``SetComparison.precision_recall_f1``).
    """
    # TODO: a bunch of this can go back into one off init in a class
    _train_opts, model = m.restore_model(run)

    # Keep the path argument and the opened database under separate names.
    labels = LabelDB(label_db_file=label_db)

    set_comparison = u.SetComparison()

    # Only the first few images are kept for visual debugging.
    max_debug_imgs = 4
    debug_imgs = []

    for filename in sorted(os.listdir(image_dir)):
        # load next image
        # TODO: this block used in various places, refactor
        # np.array() copies the pixel data, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(os.path.join(image_dir, filename)) as pil_img:
            img = np.array(pil_img)  # uint8 0->255
        img = img.astype(np.float32)
        img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

        # run through model as a batch of one, then unwrap the single result
        prediction = expit(model.predict(np.expand_dims(img, 0))[0])

        if len(debug_imgs) < max_debug_imgs:
            debug_imgs.append(u.side_by_side(rgb=img, bitmap=prediction))

        # calc [(x,y), ...] centroids
        predicted_centroids = u.centroids_of_connected_components(
            prediction, rescale=2.0, threshold=connected_components_threshold)

        # compare to true labels; stored labels have the opposite coordinate
        # order to the predicted centroids, so swap each pair first.
        true_centroids = [(y, x) for (x, y) in labels.get_labels(filename)]

        # The per-image return value is not needed here: the summary below is
        # read back from the same set_comparison object after the loop.
        set_comparison.compare_sets(true_centroids, predicted_centroids)

    precision, recall, f1 = set_comparison.precision_recall_f1()

    return {
        "debug_imgs": debug_imgs,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
示例#2
0
文件: data.py 项目: opencvfun/bnn
                      help='relative scale of label bitmap compared to input image')
  # NOTE(review): this is the tail of a command-line smoke test; the parser
  # construction and the earlier options (image dir, label dir, batch size,
  # patch size, label rescale) are defined above this excerpt.

  # Boolean switches, forwarded below as distort_rgb / random_rotation.
  parser.add_argument('--distort', action='store_true')
  parser.add_argument('--rotate', action='store_true')
  parser.add_argument('--width', type=int, default=None, help='input image width. required if --patch-width-height not set.')
  parser.add_argument('--height', type=int, default=None, help='input image height. required if --patch-width-height not set.')
  opts = parser.parse_args()
  # Echo the parsed options so the run is reproducible from its output.
  print(opts)

  # NOTE(review): neither Image nor ImageDraw is referenced in the visible
  # lines below; confirm whether this import is still needed.
  from PIL import Image, ImageDraw

  sess = tf.Session()

  # Build the (images, labels) batch tensors. Left/right flipping and
  # repetition are always enabled here; the remaining settings come from
  # the command line.
  imgs, xyss = img_xys_iterator(image_dir=opts.image_dir,
                                label_dir=opts.label_dir,
                                batch_size=opts.batch_size,
                                patch_width_height=opts.patch_width_height,
                                distort_rgb=opts.distort,
                                flip_left_right=True,
                                random_rotation=opts.rotate,
                                repeat=True,
                                width=opts.width,
                                height=opts.height,
                                label_rescale=opts.label_rescale)

  # Evaluate three batches and write one image file per batch element,
  # named test_<batch>_<element>.png in the current working directory.
  for b in range(3):
    img_batch, xys_batch = sess.run([imgs, xyss])
    for i, (img, xys) in enumerate(zip(img_batch, xys_batch)):
      fname = "test_%03d_%03d.png" % (b, i)
      print("batch", b, "element", i, "fname", fname)
      # presumably renders the input image next to its label bitmap;
      # verify against u.side_by_side
      u.side_by_side(rgb=img, bitmap=xys).save(fname)
示例#3
0
  # NOTE(review): this excerpt is the interior of a loop whose header is not
  # visible here (the `break` below exits it); idx, filename, sess, model, db
  # and export_dir are all defined outside the excerpt.
  img = (img / 127.5) - 1.0  # -1.0 -> 1.0  # see data.py

  try:
    # run single image through model
    # The image is fed as a batch of one and the single result is unwrapped.
    prediction = sess.run(model.output, feed_dict={model.imgs: [img]})[0]

    # calc [(x,y), ...] centroids
    centroids = u.centroids_of_connected_components(prediction,
                                                    rescale=2.0,
                                                    threshold=opts.connected_components_threshold)

    # One tab-separated line per image: index, filename, number of centroids.
    print("\t".join(map(str, [idx, filename, len(centroids)])))

    # export some debug image (if requested)
    # An empty --export-pngs value disables export; 'predictions' and
    # 'centroids' pick the rendering; any other value is rejected.
    if opts.export_pngs != '':
      if opts.export_pngs == 'predictions':
        debug_img = u.side_by_side(rgb=img, bitmap=prediction)
      elif opts.export_pngs == 'centroids':
        debug_img = u.red_dots(rgb=img, centroids=centroids)
      else:
        raise Exception("unknown --export-pngs option")
      # Output name is the input filename with ".png" appended.
      debug_img.save("%s/%s.png" % (export_dir, filename))

    # set new labels (if requested)
    # NOTE(review): the meaning of flip=True is not visible here; confirm
    # against the label DB's set_labels.
    if db:
      db.set_labels(filename, centroids, flip=True)

  # NOTE(review): the try body covers export and DB writes as well as the
  # model run; if only sess.run can raise OutOfRangeError, consider narrowing
  # the try to that single call.
  except tf.errors.OutOfRangeError:
    # end of iterator
    break
示例#4
0
文件: data.py 项目: bgolubovski/bees
  # NOTE(review): this is the tail of a command-line smoke test; the parser
  # object is created above this excerpt.
  parser.add_argument('--image-dir', type=str, default='images/201802_sample/training',
                      help='location of RGB input images')
  parser.add_argument('--label-dir', type=str, default='labels/201802_sample',
                      help='location of corresponding L label files')
  parser.add_argument('--batch-size', type=int, default=4)
  # NOTE(review): parsed as an int even though the help text calls it a
  # fraction; presumably a divisor of the image size - confirm against
  # img_xys_iterator.
  parser.add_argument('--patch-fraction', type=int, default=1,
                      help="what fraction of image to use as patch. 1 => no patch")
  parser.add_argument('--distort', action='store_true')
  opts = parser.parse_args()
  # Echo the parsed options so the run is reproducible from its output.
  print(opts)

  # NOTE(review): neither Image nor ImageDraw is referenced in the visible
  # lines below; confirm whether this import is still needed.
  from PIL import Image, ImageDraw

  sess = tf.Session()

  # Build the (images, labels) batch tensors. Left/right flipping, random
  # rotation and repetition are always enabled here; only the directories,
  # batch size, patch fraction and RGB distortion come from the command line.
  imgs, xyss = img_xys_iterator(image_dir=opts.image_dir,
                                label_dir=opts.label_dir,
                                batch_size=opts.batch_size,
                                patch_fraction=opts.patch_fraction,
                                distort_rgb=opts.distort,
                                flip_left_right=True,
                                random_rotation=True,
                                repeat=True)

  # Evaluate three batches and write one image file per batch element,
  # named test_<batch>_<element>.png in the current working directory.
  for b in range(3):
    print(">batch", b)
    img_batch, xys_batch = sess.run([imgs, xyss])
    for i, (img, xys) in enumerate(zip(img_batch, xys_batch)):
      print(">element", i)
      # presumably renders the input image next to its label bitmap;
      # verify against u.side_by_side
      u.side_by_side(rgb=img, bitmap=xys).save("test_%03d_%03d.png" % (b, i))