import glob
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from keypose import utils  # NB: import path may differ with repo layout.


def show_keypoints(image_dir, mesh_file):
  """Display keypoints in a keypose data file."""
  print('Looking for images in %s' % image_dir)
  filenames = glob.glob(os.path.join(image_dir, '*_L.png'))
  if not filenames:
    print("Couldn't find any PNG files in %s" % image_dir)
    sys.exit(-1)
  filenames.sort()
  print('Found %d files in %s' % (len(filenames), image_dir))

  obj = None
  if mesh_file:
    obj = utils.read_mesh(mesh_file)

  for fname in filenames:
    im_l = utils.read_image(fname)
    im_r = utils.read_image(fname.replace('_L.png', '_R.png'))
    im_mask = utils.read_image(fname.replace('_L.png', '_mask.png'))
    im_border = utils.read_image(fname.replace('_L.png', '_border.png'))
    cam, _, _, uvds, _, _, transform = utils.read_contents_pb(
        fname.replace('_L.png', '_L.pbtxt'))
    print(fname)
    ret = show_kps(im_l, im_r, im_mask, im_border, (cam, uvds, transform), obj)
    if ret:  # Nonzero return requests quitting the viewer.
      break
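

# `show_kps`, `draw_circle`, and `image_as_ubyte` are used in this file but
# defined elsewhere in the repo.  The two helpers below are a minimal sketch
# of the drawing utilities only, assuming float images convertible with
# skimage and OpenCV-style BGR drawing; the repo's actual implementations
# may differ.
import warnings

from skimage import img_as_ubyte


def image_as_ubyte(image):
  """Converts a float image to uint8, suppressing precision-loss warnings."""
  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    return img_as_ubyte(image)


def draw_circle(image, uv, color, size=2):
  """Draws a small filled circle at pixel coordinates (u, v), in place."""
  cv2.circle(image, (int(uv[0]), int(uv[1])), size, color, -1)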


def predict(model_dir, image_dir, params, camera, camera_input,
            mesh_file=None):
  """Predicts keypoints on all images in image_dir, evaluates 3D accuracy."""
  # Set up network for prediction.
  mparams = params.model_params
  num_kp = mparams.num_kp

  # Set up predictor from model directory.
  model = tf.saved_model.load(model_dir)
  predict_fn = model.signatures['serving_default']

  # One distinct color per keypoint, scaled to 0..255 for OpenCV.
  colors = plt.get_cmap('rainbow')(np.linspace(0, 1.0, num_kp))[:, :3]
  colors = (colors * 255).tolist()

  # Set up mesh object, if the mesh file exists.
  obj = None
  if mesh_file and mesh_file != 'show':
    obj = utils.read_mesh(mesh_file, num=300)
    obj.large_points = False

  # Iterate over all images in image_dir.
  total_time = 0.0
  count = 0
  mae_list = []
  filenames = glob.glob(os.path.join(image_dir, '*_L.png'))
  filenames.sort()

  for fname in filenames:
    # Read in and resize images if necessary.
    print(fname)
    im_l = utils.read_image(fname)
    targs_pb = utils.read_target_pb(fname.replace('_L.png', '_L.pbtxt'))
    kps_pb = targs_pb.kp_target
    im_l = utils.resize_image(im_l, camera, camera_input,
                              targs_pb)  # NB: changes targs_pb.
    keys_uvd_l, to_world_l, visible_l = utils.get_keypoints(kps_pb)

    im_r = utils.read_image(fname.replace('_L.png', '_R.png'))
    targs_pb = utils.read_target_pb(fname.replace('_L.png', '_R.pbtxt'))
    kps_pb = targs_pb.kp_target
    im_r = utils.resize_image(im_r, camera, camera_input,
                              targs_pb)  # NB: changes kps_pb.
    keys_uvd_r, _, _ = utils.get_keypoints(kps_pb)

    # Do cropping if called out in mparams.
    if mparams.crop:
      img0, img1, offs, visible = utils.do_occlude_crop(
          im_l, im_r, keys_uvd_l, keys_uvd_r, mparams.crop, visible_l,
          dither=0.0, var_offset=False)
      if np.any(visible == 0.0):
        print('Could not crop')
        continue
    # NB: the rest of the loop assumes mparams.crop was set, so that
    # img0, img1, and offs are defined.
    offsets = offs.astype(np.float32)
    hom = np.eye(3, dtype=np.float32)
    to_world_l = to_world_l.astype(np.float32)
    keys_uvd_l = keys_uvd_l.astype(np.float32)

    # Batch size of 1.
    img_l = tf.constant(np.expand_dims(img0[:, :, :3], 0))
    img_r = tf.constant(np.expand_dims(img1[:, :, :3], 0))
    to_world_l = tf.constant(np.expand_dims(to_world_l, 0))
    keys_uvd_l = tf.constant(np.expand_dims(keys_uvd_l, 0))
    offsets = tf.constant(np.expand_dims(offsets, 0))
    hom = tf.constant(np.expand_dims(hom, 0))
    labels = {'keys_uvd_L': keys_uvd_l, 'to_world_L': to_world_l}

    # Now do the magic.
    t0 = time.time()
    preds = predict_fn(
        img_L=img_l, img_R=img_r, to_world_L=to_world_l, offsets=offsets,
        hom=hom)
    if count > 0:  # Ignore first time, startup is long.
      total_time += time.time() - t0
    count += 1

    xyzw = tf.transpose(preds['xyzw'], [0, 2, 1])
    mae_3d = utils.world_error(labels, xyzw).numpy()
    mae_list.append(mae_3d)
    print('mae_3d:', mae_3d)

    uvdw = preds['uvdw'][0, Ellipsis].numpy()
    xyzw = xyzw[0, Ellipsis].numpy()
    offsets = offsets[0, Ellipsis].numpy()

    # uv_pix_raw is in the coords of the cropped image used by the model.
    uv_pix_raw = preds['uv_pix_raw'][0, Ellipsis].numpy()  # [num_kp, 3]
    img_l = image_as_ubyte(img_l[0, Ellipsis].numpy())
    img_r = image_as_ubyte(img_r[0, Ellipsis].numpy())
    iml_orig = cv2.resize(img_l, None, fx=2, fy=2)
    imr_orig = cv2.resize(img_r, None, fx=2, fy=2)

    if obj:
      # Project the mesh into the left camera using the predicted pose.
      p_matrix = utils.p_matrix_from_camera(camera)
      q_matrix = utils.q_matrix_from_camera(camera)
      xyzw_cam = utils.project_np(q_matrix, uvdw.T)
      obj.project_to_uvd(xyzw_cam, p_matrix)
      im_mesh = np.array(img_l)
      im_mesh = obj.draw_points(im_mesh, offsets)
      im_mesh = cv2.resize(im_mesh, None, fx=2, fy=2)
    else:
      im_mesh = np.zeros(iml_orig.shape, dtype=np.uint8)

    # Draw the predicted keypoints on the left crop.
    for i in range(uv_pix_raw.shape[0]):
      draw_circle(img_l, uv_pix_raw[i, :2], colors[i])
    im_kps = cv2.resize(img_l, None, fx=2, fy=2)

    # 2x2 tiling: input stereo pair on top, keypoints and mesh below.
    im_large = cv2.vconcat([
        cv2.hconcat([iml_orig, imr_orig]),
        cv2.hconcat([im_kps, im_mesh])
    ])
    if mesh_file:
      cv2.imshow('Left and right images; keypoints and mesh', im_large)
      key = cv2.waitKey()
      if key == ord('q'):
        break

  if count > 1:
    # The first inference is excluded from total_time, so average over
    # count - 1 timed inferences.
    print('Total time: %.1f for %d timed inferences, %.1f ms/inf' %
          (total_time, count - 1, total_time * 1000.0 / (count - 1)))
  if mae_list:
    mae_list = np.concatenate(mae_list).flatten()
    print('MAE 3D (m): %f' % np.mean(mae_list))
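

# Minimal sketch of a command-line driver for the viewer above.  The flag
# names are illustrative, not the repo's actual CLI; `predict` additionally
# needs the params and camera protos, which are loaded by utilities elsewhere
# in the repo, so only `show_keypoints` is wired up here.
if __name__ == '__main__':
  import argparse

  parser = argparse.ArgumentParser(description='Browse KeyPose keypoint data.')
  parser.add_argument('image_dir', help='Directory containing *_L.png images.')
  parser.add_argument(
      '--mesh_file', default=None, help='Optional CAD mesh to overlay.')
  args = parser.parse_args()
  show_keypoints(args.image_dir, args.mesh_file)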