def show_kps(im_l, im_r, im_border, im_mask, kps, obj=None, size=3):
    """Show the left image with keypoints, plus the border and mask images.

    Args:
      im_l: left image; converted with cv2.COLOR_BGR2RGB before drawing.
      im_r: right image; converted the same way but not displayed.
      im_border: image shown in the 'Border' window.
      im_mask: image shown in the 'Mask' window.
      kps: 3-element sequence (camera, uvd keypoints, <ignored>).
      obj: optional mesh object; when truthy, the keypoints are projected
        through the camera matrices and the mesh points are drawn on the
        left image.
      size: circle size forwarded to draw_circle.

    Returns:
      True if the key pressed in the OpenCV window was 'q', else False.
    """
    camera, keypoints, _ = kps

    left_view = cv2.cvtColor(im_l, cv2.COLOR_BGR2RGB)
    # The right image goes through the same conversion, but the result is
    # never shown by this function.
    cv2.cvtColor(im_r, cv2.COLOR_BGR2RGB)

    keypoints = np.array(keypoints)
    # Each keypoint takes every third entry of the module-level palette.
    for index, keypoint in enumerate(keypoints):
        draw_circle(left_view, keypoint, colors[index * 3], size)

    if obj:
        p_matrix = utils.p_matrix_from_camera(camera)
        q_matrix = utils.q_matrix_from_camera(camera)
        obj.project_to_uvd(utils.project_np(q_matrix, keypoints.T), p_matrix)
        left_view = obj.draw_points(left_view)

    for title, picture in (
        ('Image Left', left_view),
        ('Border', im_border),
        ('Mask', im_mask),
    ):
        cv2.imshow(title, picture)

    # Block until a key is pressed; 'q' asks the caller to stop.
    return cv2.waitKey() == ord('q')
def _read_resized_view(image_file, target_file, camera, camera_input):
    """Read one image and its target proto, resize, and extract keypoints.

    Returns:
      (image, keypoints) where keypoints is whatever utils.get_keypoints
      returns (unpacked by the caller as keys_uvd, to_world, visible).
    """
    image = utils.read_image(image_file)
    targs_pb = utils.read_target_pb(target_file)
    # Grab the keypoint target before resizing: resize_image changes
    # targs_pb, and the keypoints are read from it afterwards.
    kps_pb = targs_pb.kp_target
    image = utils.resize_image(image, camera, camera_input, targs_pb)
    return image, utils.get_keypoints(kps_pb)


def _report_results(total_time, timed_count, mae_list):
    """Print the timing summary and the mean 3D error of all inferences."""
    if not mae_list:
        # Nothing was inferred (no images, or every crop failed); avoid the
        # division by zero and the empty np.concatenate.
        print('No images were processed.')
        return
    if timed_count > 0:
        print('Total time: %.1f for %d inferences, %.1f ms/inf' %
              (total_time, timed_count, total_time * 1000.0 / timed_count))
    else:
        print('Only the warm-up inference ran; no timing available.')
    errors = np.concatenate(mae_list).flatten()
    print('MAE 3D (m): %f' % np.mean(errors))


def predict(model_dir, image_dir, params, camera, camera_input, mesh_file=None):
    """Predicts keypoints on all images in image_dir, evaluates 3D accuracy.

    Args:
      model_dir: directory of the saved model, loaded with
        tf.saved_model.load; its 'serving_default' signature is used.
      image_dir: directory searched for '*_L.png' left images. The matching
        '_R.png', '_L.pbtxt' and '_R.pbtxt' files are derived from each name.
      params: object whose model_params attribute provides num_kp and crop.
      camera: camera passed to utils.resize_image and used to build the
        projection matrices when a mesh is drawn.
      camera_input: input camera passed to utils.resize_image.
      mesh_file: None disables the display window; 'show' displays the
        images without a mesh; any other value is read with utils.read_mesh
        and drawn on the result.

    Raises:
      ValueError: if params.model_params.crop is falsy. Only the cropped
        path produces the model inputs.
    """
    # Set up network for prediction.
    mparams = params.model_params
    num_kp = mparams.num_kp
    if not mparams.crop:
        # Without cropping, img0/img1/offsets would never be defined and the
        # loop below would die with a NameError; fail early and clearly.
        raise ValueError('predict() requires params.model_params.crop to be set.')

    # Set up predictor from model directory.
    model = tf.saved_model.load(model_dir)
    predict_fn = model.signatures['serving_default']

    # One RGB color (0-255) per keypoint. plt.get_cmap replaces the
    # deprecated plt.cm.get_cmap.
    colors = plt.get_cmap('rainbow')(np.linspace(0, 1.0, num_kp))[:, :3]
    colors = (colors * 255).tolist()

    # Set up mesh object, if the mesh file exists.
    obj = None
    if mesh_file and mesh_file != 'show':
        obj = utils.read_mesh(mesh_file, num=300)
        obj.large_points = False

    # The homography input is the same identity for every image (batch of 1).
    hom = tf.constant(np.expand_dims(np.eye(3, dtype=np.float32), 0))

    # Iterate over all images in image_dir.
    total_time = 0.0
    timed_count = 0  # Inferences included in total_time (all but the first).
    count = 0
    mae_list = []
    filenames = sorted(glob.glob(os.path.join(image_dir, '*_L.png')))
    for fname in filenames:
        # Read in and resize images if necessary.
        print(fname)
        im_l, (keys_uvd_l, to_world_l, visible_l) = _read_resized_view(
            fname, fname.replace('_L.png', '_L.pbtxt'), camera, camera_input)
        im_r, (keys_uvd_r, _, _) = _read_resized_view(
            fname.replace('_L.png', '_R.png'),
            fname.replace('_L.png', '_R.pbtxt'), camera, camera_input)

        # Crop both views as called out in mparams.
        img0, img1, offs, visible = utils.do_occlude_crop(
            im_l,
            im_r,
            keys_uvd_l,
            keys_uvd_r,
            mparams.crop,
            visible_l,
            dither=0.0,
            var_offset=False)
        if np.any(visible == 0.0):
            print('Could not crop')
            continue

        # Batch size of 1.
        img_l = tf.constant(np.expand_dims(img0[:, :, :3], 0))
        img_r = tf.constant(np.expand_dims(img1[:, :, :3], 0))
        to_world_l = tf.constant(
            np.expand_dims(to_world_l.astype(np.float32), 0))
        keys_uvd_l = tf.constant(
            np.expand_dims(keys_uvd_l.astype(np.float32), 0))
        offsets = tf.constant(np.expand_dims(offs.astype(np.float32), 0))
        labels = {'keys_uvd_L': keys_uvd_l, 'to_world_L': to_world_l}

        # Now do the magic.
        t0 = time.perf_counter()
        preds = predict_fn(
            img_L=img_l,
            img_R=img_r,
            to_world_L=to_world_l,
            offsets=offsets,
            hom=hom)
        elapsed = time.perf_counter() - t0
        if count > 0:  # Ignore first time, startup is long.
            total_time += elapsed
            timed_count += 1
        count += 1

        xyzw = tf.transpose(preds['xyzw'], [0, 2, 1])
        mae_3d = utils.world_error(labels, xyzw).numpy()
        mae_list.append(mae_3d)
        print('mae_3d:', mae_3d)

        uvdw = preds['uvdw'][0, Ellipsis].numpy()
        offsets = offsets[0, Ellipsis].numpy()
        # uv_pix_raw is in the coords of the cropped image used by the model.
        uv_pix_raw = preds['uv_pix_raw'][0, Ellipsis].numpy()  # [num_kp, 3]

        img_l = image_as_ubyte(img_l[0, Ellipsis].numpy())
        img_r = image_as_ubyte(img_r[0, Ellipsis].numpy())
        # Upscaled copies are taken before keypoints are drawn on img_l.
        iml_orig = cv2.resize(img_l, None, fx=2, fy=2)
        imr_orig = cv2.resize(img_r, None, fx=2, fy=2)

        if obj:
            p_matrix = utils.p_matrix_from_camera(camera)
            q_matrix = utils.q_matrix_from_camera(camera)
            xyzw_cam = utils.project_np(q_matrix, uvdw.T)
            obj.project_to_uvd(xyzw_cam, p_matrix)
            im_mesh = np.array(img_l)
            im_mesh = obj.draw_points(im_mesh, offsets)
            im_mesh = cv2.resize(im_mesh, None, fx=2, fy=2)
        else:
            im_mesh = np.zeros(iml_orig.shape, dtype=np.uint8)

        for i, uv_pix in enumerate(uv_pix_raw):
            draw_circle(img_l, uv_pix[:2], colors[i])

        im_kps = cv2.resize(img_l, None, fx=2, fy=2)
        # 2x2 mosaic: top row inputs, bottom row keypoints and mesh.
        im_large = cv2.vconcat([
            cv2.hconcat([iml_orig, imr_orig]),
            cv2.hconcat([im_kps, im_mesh])
        ])
        if mesh_file:
            cv2.imshow('Left and right images; keypoints and mesh', im_large)
            key = cv2.waitKey()
            if key == ord('q'):
                break

    _report_results(total_time, timed_count, mae_list)