def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
    sol = perspective_n_points.efficient_pnp(
        x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
    )

    err_2d = reproj_error(x_world, y, sol.R, sol.T)
    R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
    R_quat = rotation_conversions.matrix_to_quaternion(R)

    num_pts = x_world.shape[-2]
    # quadratic part is more stable with fewer points
    num_pts_thresh = 5 if skip_q else 4
    if check_output and num_pts > num_pts_thresh:
        assert_msg = (
            f"test_perspective_n_points assertion failure for "
            f"n_points={num_pts}, "
            f"skip_quadratic={skip_q}, "
            f"no noise."
        )

        self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
        self.assertTrue((err_2d < 1e-3).all(), msg=assert_msg)

        def norm_fn(t):
            return t.norm(dim=-1)

        self.assertNormsClose(
            T, sol.T[:, None, :], rtol=4e-3, norm_fn=norm_fn, msg=assert_msg
        )
        self.assertNormsClose(
            R_quat, R_est_quat, rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
        )

    if print_stats:
        torch.set_printoptions(precision=5, sci_mode=False)
        for err_2d, err_3d, R_gt, T_gt in zip(
            sol.err_2d,
            sol.err_3d,
            torch.cat((sol.R, R), dim=-1),
            torch.stack((sol.T, T[:, 0, :]), dim=-1),
        ):
            print("2D Error: %1.4f" % err_2d.item())
            print("3D Error: %1.4f" % err_3d.item())
            print("R_hat | R_gt\n", R_gt)
            print("T_hat | T_gt\n", T_gt)
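Outside this test harness, efficient_pnp can also be called directly. The following is a minimal sketch under the PyTorch3D convention x_cam = x_world @ R + T; the batch sizes and synthetic data are illustrative only, not part of the original test:

import torch
from pytorch3d.ops import perspective_n_points
from pytorch3d.transforms import matrix_to_quaternion, random_rotations

batch, n_pts = 4, 10
R = random_rotations(batch)                        # (4, 3, 3) ground-truth rotations
T = torch.randn(batch, 3)                          # (4, 3) ground-truth translations

# Synthesize camera-frame points in front of the camera, then lift them to world space.
x_cam = torch.randn(batch, n_pts, 3)
x_cam[..., 2] = x_cam[..., 2].abs() + 2.0          # keep depths positive
y = x_cam[..., :2] / x_cam[..., 2:]                # perspective projection to 2D
x_world = torch.bmm(x_cam - T[:, None, :], R.transpose(1, 2))  # invert x_cam = x_world @ R + T

sol = perspective_n_points.efficient_pnp(x_world, y)
print(matrix_to_quaternion(sol.R))                 # estimated rotations as quaternions
print(sol.err_2d)                                  # per-batch mean reprojection error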
def test_weighted_perspective_n_points(self, batch_size=16, num_pts=200):
    # instantiate random x_world and y
    y = torch.randn((batch_size, num_pts, 2)).cuda() / 3.0
    x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(y)

    # randomly drop 50% of the rows
    weights = (torch.rand_like(x_world[:, :, 0]) > 0.5).float()

    # make sure we retain at least 6 points for each case
    weights[:, :6] = 1.0

    # fill ignored y with trash to ensure that we get a different
    # solution in case the weighting is wrong
    y = y + (1 - weights[:, :, None]) * 100.0

    def norm_fn(t):
        return t.norm(dim=-1)

    for skip_quadratic_eq in (True, False):
        # get the solution for the 0/1 weighted case
        sol = perspective_n_points.efficient_pnp(
            x_world, y, skip_quadratic_eq=skip_quadratic_eq, weights=weights
        )
        sol_R_quat = rotation_conversions.matrix_to_quaternion(sol.R)
        sol_T = sol.T

        # check that running only on points with non-zero weights ends in the
        # same place as running the 0/1 weighted version
        for i in range(batch_size):
            ok = weights[i] > 0
            x_world_ok = x_world[i, ok][None]
            y_ok = y[i, ok][None]
            sol_ok = perspective_n_points.efficient_pnp(
                x_world_ok, y_ok, skip_quadratic_eq=False
            )
            R_est_quat_ok = rotation_conversions.matrix_to_quaternion(sol_ok.R)

            self.assertNormsClose(sol_T[i], sol_ok.T[0], rtol=3e-3, norm_fn=norm_fn)
            self.assertNormsClose(
                sol_R_quat[i], R_est_quat_ok[0], rtol=3e-4, norm_fn=norm_fn
            )
def interpolate_cameras(C, R):
    if torch.isfinite(C).all():
        return C.clone(), R.clone()

    from pytorch3d.transforms import rotation_conversions
    from scipy.interpolate import interp1d

    quats = rotation_conversions.matrix_to_quaternion(R)
    n_frames = C.shape[0]
    y = torch.cat((quats, C), dim=1).numpy()
    x = torch.arange(n_frames).float().numpy()
    ok = np.isfinite(y.mean(1))
    fi = interp1d(
        x[ok],
        y[ok],
        kind='linear',
        bounds_error=False,
        axis=0,
        fill_value=(y[ok][0], y[ok][-1]),
    )
    y_interp = fi(x)
    i_quats = torch.tensor(y_interp[:, :4]).float()
    i_R = rotation_conversions.quaternion_to_matrix(i_quats)
    i_C = torch.tensor(y_interp[:, 4:]).float()
    return i_C, i_R
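A minimal usage sketch for the helper above, assuming C is an (N, 3) tensor of camera centers with non-finite rows for missing frames, R the matching (N, 3, 3) rotations, and numpy/torch/pytorch3d importable where the helper lives; the random data below is illustrative only:

import numpy as np
import torch
from pytorch3d.transforms import random_rotations

R = random_rotations(5)              # (5, 3, 3) per-frame rotations
C = torch.randn(5, 3)                # (5, 3) camera centers
C[2] = float("nan")                  # simulate a frame with a missing pose
C_filled, R_filled = interpolate_cameras(C, R)
assert torch.isfinite(C_filled).all() and torch.isfinite(R_filled).all()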
def test_quat_grad_exists(self):
    """Quaternion calculations are differentiable."""
    rotation = random_rotation()
    rotation.requires_grad = True
    modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
    [g] = torch.autograd.grad(modified.sum(), rotation)
    self.assertTrue(torch.isfinite(g).all())
def preprocess_poses(cls, poses: tuple):
    """Generates an (N, 7) vector of absolute poses.

    Args:
        poses: tuple of batched rotations (N, 3, 3) and translations (N, 3)
            in PyTorch3D view-to-world coordinates, usually returned from a
            call to RenderManager._trajectory.
            More information about PyTorch3D's coordinate system:
            https://github.com/facebookresearch/pytorch3d/blob/master/docs/notes/cameras.md

    1. Computes rotation and translation matrices in view-to-world coordinates.
    2. Generates a unit quaternion from R (the log-quaternion step is currently disabled).
    3. Normalizes the translation according to its mean and stdev.

    Returns:
        (N, 7) vector: [t1, t2, t3, qw, qx, qy, qz]
    """
    R, T = poses
    cam_wvt = get_world_to_view_transform(R=R, T=T)
    pose_transform = cam_wvt.inverse().get_matrix()
    T = pose_transform[:, 3, :3]
    R = pose_transform[:, :3, :3]

    # Compute pose stats
    std_R, mean_R = torch.std_mean(R)
    std_T, mean_T = torch.std_mean(T)

    q = rc.matrix_to_quaternion(R)
    # q /= torch.norm(q)
    # q *= torch.sign(q[0])  # hemisphere constraint
    # logq = qlog(q)

    T -= mean_T
    T /= std_T

    return torch.cat((T, q), dim=1)
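A hypothetical call sketch, assuming preprocess_poses is exposed as a classmethod on the author's EvMaskPoseDataset and PyTorch3D is installed; the random poses are illustrative only:

import torch
from pytorch3d.transforms import random_rotations

R = random_rotations(8)          # (8, 3, 3) view-to-world rotations
T = torch.randn(8, 3)            # (8, 3) translations
poses = EvMaskPoseDataset.preprocess_poses((R, T))
print(poses.shape)               # torch.Size([8, 7]): [t1, t2, t3, qw, qx, qy, qz]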
def test_quaternion_multiplication(self):
    """Quaternion and matrix multiplication are equivalent."""
    a = random_quaternions(15, torch.float64).reshape((3, 5, 4))
    b = random_quaternions(21, torch.float64).reshape((7, 3, 1, 4))
    ab = quaternion_multiply(a, b)
    self.assertEqual(ab.shape, (7, 3, 5, 4))
    a_matrix = quaternion_to_matrix(a)
    b_matrix = quaternion_to_matrix(b)
    ab_matrix = torch.matmul(a_matrix, b_matrix)
    ab_from_matrix = matrix_to_quaternion(ab_matrix)
    self._assert_quaternions_close(ab, ab_from_matrix)
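The same equivalence can be sanity-checked on a single pair outside the test class; a minimal sketch (PyTorch3D quaternions are real-first, (w, x, y, z), and are only defined up to sign, hence the two-sided comparison):

import torch
from pytorch3d.transforms import (
    matrix_to_quaternion,
    quaternion_multiply,
    quaternion_to_matrix,
    random_quaternions,
)

a = random_quaternions(1, dtype=torch.float64)
b = random_quaternions(1, dtype=torch.float64)
q_ab = quaternion_multiply(a, b)
m_ab = matrix_to_quaternion(quaternion_to_matrix(a) @ quaternion_to_matrix(b))
assert torch.allclose(q_ab, m_ab, atol=1e-9) or torch.allclose(q_ab, -m_ab, atol=1e-9)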
def test_matrix_to_quaternion_corner_case(self):
    """Check no bad gradients from sqrt(0)."""
    matrix = torch.eye(3, requires_grad=True)
    target = torch.Tensor([0.984808, 0, 0.174, 0])
    optimizer = torch.optim.Adam([matrix], lr=0.05)
    optimizer.zero_grad()
    q = matrix_to_quaternion(matrix)
    loss = torch.sum((q - target) ** 2)
    loss.backward()
    optimizer.step()

    self.assertClose(matrix, 0.95 * torch.eye(3))
def interp_rotation(r1, r2, interp_factor):
    '''
    Given two rotation matrices r1 and r2, returns a rotation that is
    interp_factor between them; when factor=0, returns r1, and when factor=1,
    returns r2. Linearly interpolates along the geodesic between the rotations
    by converting the relative rotation to angle-axis form and scaling the
    angle by interp_factor. If r1 and r2 are pi radians apart, the returned
    rotation axis will be arbitrary.
    '''
    assert interp_factor >= 0. and interp_factor <= 1.

    # Convert to angle-axis, interpolate the angle, convert back.
    # When interp_factor = 0, this returns r1. When interp_factor = 1, this
    # returns r2.
    rel = torch.matmul(r2, r1.transpose(0, 1))
    rel_axis_angle = quaternion_to_axis_angle(matrix_to_quaternion(rel))

    # Scaling keeps the axis the same, but changes the angle.
    scaled_rel_axis_angle = interp_factor * rel_axis_angle
    return torch.matmul(axis_angle_to_matrix(scaled_rel_axis_angle), r1)
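A small usage sketch of the interpolator above, with the pytorch3d.transforms helpers it relies on imported explicitly; the random rotations are illustrative only:

import torch
from pytorch3d.transforms import (
    axis_angle_to_matrix,
    matrix_to_quaternion,
    quaternion_to_axis_angle,
    random_rotation,
)

r1 = random_rotation()
r2 = random_rotation()
r_half = interp_rotation(r1, r2, 0.5)     # geodesic midpoint between r1 and r2
# The result should still be a valid rotation: R R^T = I and det(R) = 1.
assert torch.allclose(r_half @ r_half.T, torch.eye(3), atol=1e-5)
assert abs(torch.det(r_half).item() - 1.0) < 1e-5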
def test_matrix_to_quaternion_corner_case(self):
    """Check no bad gradients from sqrt(0)."""
    matrix = torch.eye(3, requires_grad=True)
    target = torch.Tensor([0.984808, 0, 0.174, 0])
    optimizer = torch.optim.Adam([matrix], lr=0.05)
    optimizer.zero_grad()
    q = matrix_to_quaternion(matrix)
    loss = torch.sum((q - target) ** 2)
    loss.backward()
    optimizer.step()

    self.assertClose(matrix, matrix, msg="Result has non-finite values")
    delta = 1e-2
    self.assertLess(
        matrix.trace(),
        3.0 - delta,
        msg="Identity initialisation unchanged by a gradient step",
    )
def test_matrix_to_quaternion_by_pi(self):
    # We check that rotations by pi around each of the 26
    # nonzero vectors containing nothing but 0, 1 and -1
    # are mapped to the right quaternions.
    # This is representative across the directions.
    options = [0.0, -1.0, 1.0]
    axes = [
        torch.tensor(vec)
        for vec in itertools.islice(  # exclude [0, 0, 0]
            itertools.product(options, options, options), 1, None
        )
    ]

    axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1)
    # Rotation by pi around unit vector x is given by
    # the matrix 2 x x^T - Id.
    R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3)
    quats_hat = matrix_to_quaternion(R)
    R_hat = quaternion_to_matrix(quats_hat)
    self.assertClose(R, R_hat, atol=1e-3)
def test_to_quat(self):
    """mtx -> quat -> mtx"""
    data = random_rotations(13, dtype=torch.float64)
    mdata = quaternion_to_matrix(matrix_to_quaternion(data))
    self.assertTrue(torch.allclose(data, mdata))
def test_from_quat(self):
    """quat -> mtx -> quat"""
    data = random_quaternions(13, dtype=torch.float64)
    mdata = matrix_to_quaternion(quaternion_to_matrix(data))
    self.assertTrue(torch.allclose(data, mdata))
def pred_synth(segpose, params: Params, mesh_type: str = "dolphin", device: str = "cuda"):

    if not params.pred_dir or not os.path.exists(params.pred_dir):
        raise FileNotFoundError(
            "Prediction directory has not been set or the file does not exist, please set using cli args or params"
        )
    pred_folders = [join(params.pred_dir, f) for f in os.listdir(params.pred_dir)]
    count = 1
    for p in sorted(pred_folders):
        try:
            print(p)
            manager = RenderManager.from_path(p)
            manager.rectify_paths(base_folder=params.pred_dir)
        except FileNotFoundError:
            continue

        # Run Silhouette Prediction Network
        logging.info("Starting mask predictions")
        mask_priors = []
        R_pred, T_pred = [], []
        q_loss, t_loss = 0, 0

        # Collect Translation stats
        R_gt, T_gt = manager._trajectory
        poses_gt = EvMaskPoseDataset.preprocess_poses(manager._trajectory)
        std_T, mean_T = torch.std_mean(T_gt)

        for idx in range(len(manager)):
            try:
                ev_frame = manager.get_event_frame(idx)
            except Exception as e:
                print(e)
                break
            mask_pred, pose_pred = predict_segpose(
                segpose, ev_frame, params.threshold_conf, params.img_size
            )
            # mask_pred = smooth_predicted_mask(mask_pred)
            manager.add_pred(idx, mask_pred, "silhouette")
            mask_priors.append(torch.from_numpy(mask_pred))

            # Make qexp a torch function
            # q_pred = qexp(pose_pred[:, 3:])
            # q_targ = qexp(poses_gt[idx, 3:].unsqueeze(0))
            #### SHOULD THIS BE NORMALIZED ??
            q_pred = pose_pred[:, 3:]
            q_targ = poses_gt[idx, 3:]
            q_pred_unit = q_pred / torch.norm(q_pred)
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("learnt: ", q_pred_unit, q_targ_unit)

            t_pred = pose_pred[:, :3] * std_T + mean_T
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            T_pred.append(t_pred)

            q_loss += quaternion_angular_error(q_pred_unit, q_targ_unit)
            t_loss += t_error(t_pred, t_targ)

            r_pred = rc.quaternion_to_matrix(q_pred).unsqueeze(0)
            R_pred.append(r_pred.squeeze(0))

        q_loss_mean = q_loss / (idx + 1)
        t_loss_mean = t_loss / (idx + 1)

        # Convert R, T to world-to-view transforms --> PyTorch3D convention
        R_pred_abs = torch.cat(R_pred)
        T_pred_abs = torch.cat(T_pred)
        # Take inverse of view-to-world (output of net) to obtain w2v
        wtv_trans = (
            get_world_to_view_transform(R=R_pred_abs, T=T_pred_abs).inverse().get_matrix()
        )
        T_pred = wtv_trans[:, 3, :3]
        R_pred = wtv_trans[:, :3, :3]
        R_pred_test = look_at_rotation(T_pred_abs)
        T_pred_test = -torch.bmm(
            R_pred_test.transpose(1, 2), T_pred_abs[:, :, None]
        )[:, :, 0]
        # Convert back to view-to-world to get absolute
        vtw_trans = (
            get_world_to_view_transform(R=R_pred_test, T=T_pred_test).inverse().get_matrix()
        )
        T_pred_trans = vtw_trans[:, 3, :3]
        R_pred_trans = vtw_trans[:, :3, :3]

        # Calc pose error for this:
        q_loss_mean_test = 0
        t_loss_mean_test = 0
        for idx in range(len(R_pred_test)):
            q_pred_trans = rc.matrix_to_quaternion(R_pred_trans[idx]).squeeze()
            q_targ = poses_gt[idx, 3:]
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("look: ", q_test, q_targ)
            q_loss_mean_test += quaternion_angular_error(q_pred_trans, q_targ_unit)
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            t_loss_mean_test += t_error(T_pred_trans[idx], t_targ)
        q_loss_mean_test /= idx + 1
        t_loss_mean_test /= idx + 1

        logging.info(
            f"Mean Translation Error: {t_loss_mean}; Mean Rotation Error: {q_loss_mean}"
        )
        logging.info(
            f"Mean Translation Error: {t_loss_mean_test}; Mean Rotation Error: {q_loss_mean_test}"
        )

        # Plot estimated cameras
        logging.info("Plotting pose map")
        idx = random.sample(range(len(R_gt)), k=2)
        pose_plot = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred[idx], T_pred[idx]), params.device
        )
        pose_plot_test = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred_test[idx], T_pred_test[idx]), params.device
        )
        manager.add_pose_plot(pose_plot, "rot+trans")
        manager.add_pose_plot(pose_plot_test, "trans")

        count += 1
        groundtruth_silhouettes = manager._images("silhouette") / 255.0
        print(groundtruth_silhouettes.shape)
        print(torch.stack(mask_priors).shape)
        seg_iou = neg_iou_loss(groundtruth_silhouettes, torch.stack(mask_priors) / 255.0)
        print("Seg IoU", seg_iou)

        # RUN MESH DEFORMATION
        # Run it 3 times: w/ Rot+Trans - w/ Trans+LookAt - w/ GT Pose
        experiments = {
            "GT-Pose": [R_gt, T_gt],
            # "Rot+Trans": [R_pred, T_pred],
            # "Trans+LookAt": [R_pred_test, T_pred_test]
        }

        results = {}
        input_m = torch.stack(mask_priors)

        for i in range(len(experiments.keys())):
            logging.info(f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
            # The MeshDeformation model will return silhouettes across all views by default
            mesh_model = MeshDeformationModel(device=device, params=params)

            R, T = list(experiments.values())[i]
            experiment_results = mesh_model.run_optimization(input_m, R, T)
            renders = mesh_model.render_final_mesh((R, T), "predict", input_m.shape[-2:])

            mesh_silhouettes = renders["silhouettes"].squeeze(1)
            mesh_images = renders["images"].squeeze(1)
            experiment_name = list(experiments.keys())[i]
            for idx in range(len(mesh_silhouettes)):
                manager.add_pred(
                    idx,
                    mesh_silhouettes[idx].cpu().numpy(),
                    "silhouette",
                    destination=f"mesh_{experiment_name}",
                )
                manager.add_pred(
                    idx,
                    mesh_images[idx].cpu().numpy(),
                    "phong",
                    destination=f"mesh_{experiment_name}",
                )

            # Calculate chamfer loss:
            mesh_pred = mesh_model._final_mesh
            if mesh_type == "dolphin":
                path = "data/meshes/dolphin/dolphin.obj"
                mesh_gt = load_objs_as_meshes(
                    [path],
                    create_texture_atlas=False,
                    load_textures=True,
                    device=device,
                )
            # Shapenet Cars
            elif mesh_type == "shapenet":
                mesh_info = manager.metadata["mesh_info"]
                path = os.path.join(
                    params.gt_mesh_path,
                    f"ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj",
                )
                # path = f"data/ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                try:
                    verts, faces, aux = load_obj(
                        path, load_textures=True, create_texture_atlas=True
                    )
                    mesh_gt = Meshes(
                        verts=[verts],
                        faces=[faces.verts_idx],
                        textures=TexturesAtlas(atlas=[aux.texture_atlas]),
                    ).to(device)
                except:
                    mesh_gt = None
                    print("CANNOT COMPUTE CHAMFER LOSS")

            if mesh_gt:
                mesh_pred_compute, mesh_gt_compute = scale_meshes(
                    mesh_pred.clone(), mesh_gt.clone()
                )
                pcl_pred = sample_points_from_meshes(mesh_pred_compute, num_samples=5000)
                pcl_gt = sample_points_from_meshes(mesh_gt_compute, num_samples=5000)
                chamfer_loss = chamfer_distance(pcl_pred, pcl_gt, point_reduction="mean")
                print("CHAMFER LOSS: ", chamfer_loss)
                experiment_results["chamfer_loss"] = (
                    chamfer_loss[0].cpu().detach().numpy().tolist()
                )

            mesh_iou = neg_iou_loss(groundtruth_silhouettes, mesh_silhouettes)
            experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()
            results[experiment_name] = experiment_results
            manager.add_pred_mesh(mesh_pred, experiment_name)

        # logging.info(f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
        # # The MeshDeformation model will return silhouettes across all views by default
        #
        # experiment_results = models["mesh"].run_optimization(input_m, R_gt, T_gt)
        # renders = models["mesh"].render_final_mesh(
        #     (R_gt, T_gt), "predict", input_m.shape[-2:]
        # )
        #
        # mesh_silhouettes = renders["silhouettes"].squeeze(1)
        # mesh_images = renders["images"].squeeze(1)
        # experiment_name = params.name
        # for idx in range(len(mesh_silhouettes)):
        #     manager.add_pred(
        #         idx,
        #         mesh_silhouettes[idx].cpu().numpy(),
        #         "silhouette",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #     manager.add_pred(
        #         idx,
        #         mesh_images[idx].cpu().numpy(),
        #         "phong",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #
        # # Calculate chamfer loss:
        # mesh_pred = models["mesh"]._final_mesh
        # if mesh_type == "dolphin":
        #     path = params.gt_mesh_path
        #     mesh_gt = load_objs_as_meshes(
        #         [path],
        #         create_texture_atlas=False,
        #         load_textures=True,
        #         device=device,
        #     )
        # # Shapenet Cars
        # elif mesh_type == "shapenet":
        #     mesh_info = manager.metadata["mesh_info"]
        #     path = params.gt_mesh_path
        #     try:
        #         verts, faces, aux = load_obj(
        #             path, load_textures=True, create_texture_atlas=True
        #         )
        #         mesh_gt = Meshes(
        #             verts=[verts],
        #             faces=[faces.verts_idx],
        #             textures=TexturesAtlas(atlas=[aux.texture_atlas]),
        #         ).to(device)
        #     except:
        #         mesh_gt = None
        #         print("CANNOT COMPUTE CHAMFER LOSS")
        #
        # if mesh_gt and params.is_real_data:
        #     mesh_pred_compute, mesh_gt_compute = scale_meshes(
        #         mesh_pred.clone(), mesh_gt.clone()
        #     )
        #     pcl_pred = sample_points_from_meshes(mesh_pred_compute, num_samples=5000)
        #     pcl_gt = sample_points_from_meshes(mesh_gt_compute, num_samples=5000)
        #     chamfer_loss = chamfer_distance(pcl_pred, pcl_gt, point_reduction="mean")
        #     print("CHAMFER LOSS: ", chamfer_loss)
        #     experiment_results["chamfer_loss"] = (
        #         chamfer_loss[0].cpu().detach().numpy().tolist()
        #     )
        #
        # mesh_iou = neg_iou_loss_all(groundtruth_silhouettes, mesh_silhouettes)
        # experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()
        # results[experiment_name] = experiment_results
        # manager.add_pred_mesh(mesh_pred, experiment_name)

        seg_iou = neg_iou_loss_all(groundtruth_silhouettes, input_m / 255.0)
        gt_iou = neg_iou_loss_all(groundtruth_silhouettes, groundtruth_silhouettes)

        results["mesh_iou"] = mesh_iou.detach().cpu().numpy().tolist()
        results["seg_iou"] = seg_iou.detach().cpu().numpy().tolist()
        logging.info(f"Mesh IOU list & results: {mesh_iou}")
        logging.info(f"Seg IOU list & results: {seg_iou}")
        logging.info(f"GT IOU list & results: {gt_iou}")

        # results["mean_iou"] = IOULoss().forward(groundtruth, mesh_silhouettes).detach().cpu().numpy().tolist()
        # results["mean_dice"] = DiceCoeffLoss().forward(groundtruth, mesh_silhouettes)

        manager.set_pred_results(results)
        manager.close()
def compute_rotation_quanternion(points):
    pca = PCA(n_components=3)
    pca.fit(points)
    return matrix_to_quaternion(
        torch.from_numpy(pca.components_.astype(np.float32))
    )
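A minimal sketch of calling the PCA-based helper above on a synthetic point cloud, assuming scikit-learn's PCA, numpy, torch, and pytorch3d's matrix_to_quaternion are imported where it is defined. Note that pca.components_ is only guaranteed orthonormal, so the result is a meaningful rotation only when its determinant is +1:

import numpy as np

# Anisotropic synthetic point cloud so the principal axes are well defined.
points = np.random.randn(1000, 3).astype(np.float32) * np.array([5.0, 2.0, 0.5], dtype=np.float32)
q = compute_rotation_quanternion(points)
print(q)   # a (4,) real-first quaternion (w, x, y, z)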