Exemplo n.º 1
0
# FIXME: Test hyperparameters
hp_d['batch_size'] = 8
""" 3. Build graph, load weights, initialize a session and start test """
# Initialize
graph = tf.get_default_graph()
config = tf.ConfigProto()
# Let TensorFlow grow its GPU memory allocation on demand instead of
# reserving the whole device up front.
config.gpu_options.allow_growth = True

# The model is built before the Saver is created: a Saver constructed without
# arguments collects the variables that exist in the graph at that moment.
model = ConvNet([IM_SIZE[0], IM_SIZE[1], 3], NUM_CLASSES, **hp_d)
evaluator = Evaluator()
saver = tf.train.Saver()

# The session is only needed to restore the checkpoint and run inference.
# Using it as a context manager guarantees it is closed afterwards, even if
# restoring or predicting raises.
with tf.Session(graph=graph, config=config) as sess:
    saver.restore(sess, './model.ckpt')  # restore learned weights
    test_y_pred = model.predict(sess, test_set, **hp_d)
test_score = evaluator.score(test_set.labels, test_y_pred)

print('Test accuracy: {}'.format(test_score))
""" 4. Draw masks on image """
draw_dir = os.path.join(test_dir, 'draws')  # FIXME
if not os.path.isdir(draw_dir):
    os.mkdir(draw_dir)
im_dir = os.path.join(test_dir, 'images')  # FIXME
im_paths = glob.glob(os.path.join(im_dir, '*.jpg'))
test_outputs = draw_pixel(test_y_pred)
# Element-wise sum of the drawn masks and the input images.
# NOTE(review): presumably this overlays the mask on each image; the value
# range/dtype expected by cv2.imwrite is not checked here - confirm.
test_results = test_outputs + test_set.images
# NOTE(review): glob.glob() returns paths in arbitrary (OS-dependent) order and
# zip() stops at the shorter sequence. This pairing is only correct if
# test_set was loaded from the same files in the same order - verify against
# the dataset loader.
for img, im_path in zip(test_results, im_paths):
    # os.path.basename handles the platform's path separator. Splitting on
    # '/' breaks on Windows, where `name` could end up being the full source
    # path and the write below would target the original image.
    name = os.path.basename(im_path)
    draw_path = os.path.join(draw_dir, name)
    cv2.imwrite(draw_path, img)
Exemplo n.º 2
0
""" 3. Build graph, load weights, initialize a session """
# Initialize
graph = tf.get_default_graph()
config = tf.ConfigProto()
# Let TensorFlow grow its GPU memory allocation on demand instead of
# reserving the whole device up front.
config.gpu_options.allow_growth = True

# The model is built before the Saver is created: a Saver constructed without
# arguments collects the variables that exist in the graph at that moment.
model = ConvNet([IM_SIZE[0], IM_SIZE[1], 3], NUM_CLASSES, **hp_d)
saver = tf.train.Saver()

# Side length (in pixels) of the square frames requested from the camera,
# fed to the model and shown on screen.
# NOTE(review): presumably this should agree with IM_SIZE - confirm.
frame_w, frame_h = 512, 512

# The `with` block closes the session on exit; the try/finally releases the
# camera and closes the preview windows even if inference raises mid-loop.
with tf.Session(graph=graph, config=config) as sess:
    saver.restore(sess, './model.ckpt')    # restore learned weights

    capture = cv2.VideoCapture(0)  # device index 0
    try:
        # The driver may ignore the requested size, so every frame is
        # resized explicitly below as well.
        capture.set(cv2.CAP_PROP_FRAME_WIDTH, frame_w)
        capture.set(cv2.CAP_PROP_FRAME_HEIGHT, frame_h)

        while True:
            ret, frame = capture.read()
            if not ret:
                # No frame was grabbed (camera missing, unplugged or stream
                # ended); `frame` is unusable, so stop instead of crashing
                # inside cv2.resize.
                print('Failed to read a frame from the camera; stopping.')
                break
            resize = cv2.resize(frame, dsize=(frame_w, frame_h),
                                interpolation=cv2.INTER_AREA)

            # The model is given a one-element batch holding this frame.
            test_y_pred = model.predict_video(sess, [resize, ], **hp_d)
            mask = draw_pixel(test_y_pred)
            # Collapse draw_pixel's output to a single H x W x 3 image.
            result = mask.reshape(frame_h, frame_w, 3)

            cv2.imshow("origin", resize)
            cv2.imshow("mask", result)
            # waitKey returns -1 when no key was pressed during the 1 ms
            # wait; any key press ends the loop.
            if cv2.waitKey(1) > 0:
                break
    finally:
        capture.release()
        cv2.destroyAllWindows()