Example No. 1
    def _run_and_print(self,
                       x_world,
                       y,
                       R,
                       T,
                       print_stats,
                       skip_q,
                       check_output=False):
        sol = perspective_n_points.efficient_pnp(x_world,
                                                 y.expand_as(
                                                     x_world[:, :, :2]),
                                                 skip_quadratic_eq=skip_q)

        err_2d = reproj_error(x_world, y, sol.R, sol.T)
        R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
        R_quat = rotation_conversions.matrix_to_quaternion(R)

        num_pts = x_world.shape[-2]
        # quadratic part is more stable with fewer points
        num_pts_thresh = 5 if skip_q else 4
        if check_output and num_pts > num_pts_thresh:
            assert_msg = (f"test_perspective_n_points assertion failure for "
                          f"n_points={num_pts}, "
                          f"skip_quadratic={skip_q}, "
                          f"no noise.")

            self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
            self.assertTrue((err_2d < 1e-3).all(), msg=assert_msg)

            def norm_fn(t):
                return t.norm(dim=-1)

            self.assertNormsClose(T,
                                  sol.T[:, None, :],
                                  rtol=4e-3,
                                  norm_fn=norm_fn,
                                  msg=assert_msg)
            self.assertNormsClose(R_quat,
                                  R_est_quat,
                                  rtol=3e-3,
                                  norm_fn=norm_fn,
                                  msg=assert_msg)

        if print_stats:
            torch.set_printoptions(precision=5, sci_mode=False)
            for err_2d, err_3d, R_gt, T_gt in zip(
                    sol.err_2d,
                    sol.err_3d,
                    torch.cat((sol.R, R), dim=-1),
                    torch.stack((sol.T, T[:, 0, :]), dim=-1),
            ):
                print("2D Error: %1.4f" % err_2d.item())
                print("3D Error: %1.4f" % err_3d.item())
                print("R_hat | R_gt\n", R_gt)
                print("T_hat | T_gt\n", T_gt)
Example No. 2
    def test_weighted_perspective_n_points(self, batch_size=16, num_pts=200):
        # instantiate random x_world and y
        y = torch.randn((batch_size, num_pts, 2)).cuda() / 3.0
        x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(
            y)

        # randomly drop 50% of the rows
        weights = (torch.rand_like(x_world[:, :, 0]) > 0.5).float()

        # make sure we retain at least 6 points for each case
        weights[:, :6] = 1.0

        # fill the ignored y with trash to ensure that we would get a
        # different solution if the weighting were wrong
        y = y + (1 - weights[:, :, None]) * 100.0

        def norm_fn(t):
            return t.norm(dim=-1)

        for skip_quadratic_eq in (True, False):
            # get the solution for the 0/1 weighted case
            sol = perspective_n_points.efficient_pnp(
                x_world,
                y,
                skip_quadratic_eq=skip_quadratic_eq,
                weights=weights)
            sol_R_quat = rotation_conversions.matrix_to_quaternion(sol.R)
            sol_T = sol.T

            # check that running only on points with non-zero weights ends in the
            # same place as running the 0/1 weighted version
            for i in range(batch_size):
                ok = weights[i] > 0
                x_world_ok = x_world[i, ok][None]
                y_ok = y[i, ok][None]
                sol_ok = perspective_n_points.efficient_pnp(
                    x_world_ok, y_ok, skip_quadratic_eq=False)
                R_est_quat_ok = rotation_conversions.matrix_to_quaternion(
                    sol_ok.R)

                self.assertNormsClose(sol_T[i],
                                      sol_ok.T[0],
                                      rtol=3e-3,
                                      norm_fn=norm_fn)
                self.assertNormsClose(sol_R_quat[i],
                                      R_est_quat_ok[0],
                                      rtol=3e-4,
                                      norm_fn=norm_fn)
Example No. 3
def interpolate_cameras(C, R):
    if torch.isfinite(C).all():
        return C.clone(), R.clone()

    import numpy as np
    from pytorch3d.transforms import rotation_conversions
    from scipy.interpolate import interp1d

    quats = rotation_conversions.matrix_to_quaternion(R)

    n_frames = C.shape[0]
    y = torch.cat((quats, C), dim=1).numpy()
    x = torch.arange(n_frames).float().numpy()
    ok = np.isfinite(y.mean(1))

    fi = interp1d(
        x[ok], y[ok], kind='linear',
        bounds_error=False, axis=0,
        fill_value=(y[ok][0], y[ok][-1])
    )

    y_interp = fi(x)

    i_quats = torch.tensor(y_interp[:, :4]).float()
    i_R = rotation_conversions.quaternion_to_matrix(i_quats)
    i_C = torch.tensor(y_interp[:, 4:]).float()

    return i_C, i_R
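
A hedged usage sketch for interpolate_cameras, assuming missing cameras are marked by NaN centers as the function expects; the frame count and values are arbitrary:

import torch
from pytorch3d.transforms import random_rotations

R = random_rotations(6)
C = torch.linspace(0.0, 5.0, 6)[:, None].repeat(1, 3)  # (6, 3) camera centers
C[2] = float("nan")  # frames 2 and 4 are missing and get interpolated
C[4] = float("nan")

C_interp, R_interp = interpolate_cameras(C, R)
assert torch.isfinite(C_interp).all() and torch.isfinite(R_interp).all()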
Example No. 4
 def test_quat_grad_exists(self):
     """Quaternion calculations are differentiable."""
     rotation = random_rotation()
     rotation.requires_grad = True
     modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
     [g] = torch.autograd.grad(modified.sum(), rotation)
     self.assertTrue(torch.isfinite(g).all())
Example No. 5
    def preprocess_poses(cls, poses: tuple):
        """Generates (N, 6) vector of absolute poses
        Args:
            Tuple of batched rotations (N, 3, 3) and translations (N, 3) in Pytorch3d view-to-world coordinates. usually returned from a call to RenderManager._trajectory
            More information about Pytorch3D's coordinate system: https://github.com/facebookresearch/pytorch3d/blob/master/docs/notes/cameras.md

        1. Computes rotation and translation matrices in view-to-world coordinates.
        2. Generates unit quaternion from R and computes log q repr
        3. Normalizes translation according to mean and stdev

        Returns:
            (N, 6) vector: [t1, t2, t3, logq1, logq2, logq3]
        """
        R, T = poses
        cam_wvt = get_world_to_view_transform(R=R, T=T)
        pose_transform = cam_wvt.inverse().get_matrix()
        T = pose_transform[:, 3, :3]
        R = pose_transform[:, :3, :3]

        # Compute pose stats
        std_R, mean_R = torch.std_mean(R)
        std_T, mean_T = torch.std_mean(T)

        q = rc.matrix_to_quaternion(R)
        # q /= torch.norm(q)
        # q *= torch.sign(q[0])  # hemisphere constraint
        # logq = qlog(q)

        T -= mean_T
        T /= std_T

        return torch.cat((T, q), dim=1)
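
The same recipe as a self-contained sketch (view-to-world R, T mapped to a [translation | quaternion] vector). The batch size is arbitrary and get_world_to_view_transform is assumed importable from pytorch3d.renderer.cameras, mirroring the method above:

import torch
from pytorch3d.renderer.cameras import get_world_to_view_transform
from pytorch3d.transforms import random_rotations, rotation_conversions as rc

R = random_rotations(8)  # (8, 3, 3)
T = torch.randn(8, 3)    # (8, 3)

pose = get_world_to_view_transform(R=R, T=T).inverse().get_matrix()  # (8, 4, 4)
q = rc.matrix_to_quaternion(pose[:, :3, :3])  # (8, 4), real part first
t = pose[:, 3, :3]                            # (8, 3)
t = (t - t.mean()) / t.std()                  # normalize as in the method above
pose_vec = torch.cat((t, q), dim=1)           # (8, 7)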
Example No. 6
 def test_quaternion_multiplication(self):
     """Quaternion and matrix multiplication are equivalent."""
     a = random_quaternions(15, torch.float64).reshape((3, 5, 4))
     b = random_quaternions(21, torch.float64).reshape((7, 3, 1, 4))
     ab = quaternion_multiply(a, b)
     self.assertEqual(ab.shape, (7, 3, 5, 4))
     a_matrix = quaternion_to_matrix(a)
     b_matrix = quaternion_to_matrix(b)
     ab_matrix = torch.matmul(a_matrix, b_matrix)
     ab_from_matrix = matrix_to_quaternion(ab_matrix)
     self._assert_quaternions_close(ab, ab_from_matrix)
Example No. 7
    def test_matrix_to_quaternion_corner_case(self):
        """Check no bad gradients from sqrt(0)."""
        matrix = torch.eye(3, requires_grad=True)
        target = torch.Tensor([0.984808, 0, 0.174, 0])

        optimizer = torch.optim.Adam([matrix], lr=0.05)
        optimizer.zero_grad()
        q = matrix_to_quaternion(matrix)
        loss = torch.sum((q - target)**2)
        loss.backward()
        optimizer.step()

        self.assertClose(matrix, 0.95 * torch.eye(3))
Example No. 8
def interp_rotation(r1, r2, interp_factor):
    ''' Given two rotation matrices r1 and r2, returns a rotation that is
    interp_factor of the way between them: when interp_factor=0 it returns r1,
    and when interp_factor=1 it returns r2. Interpolates linearly along the
    geodesic between the rotations by converting the relative rotation to
    axis-angle form and scaling the angle by interp_factor. If r1 and r2 are
    pi radians apart, the returned rotation axis is arbitrary. '''
    assert interp_factor >= 0. and interp_factor <= 1.
    # Convert to axis-angle, interpolate the angle, convert back.
    # When interp_factor = 0 this returns r1; when interp_factor = 1 it
    # returns r2.
    rel = torch.matmul(r2, r1.transpose(0, 1))
    rel_axis_angle = quaternion_to_axis_angle(matrix_to_quaternion(rel))
    # Scaling keeps axis the same, but changes angle.
    scaled_rel_axis_angle = interp_factor * rel_axis_angle
    return torch.matmul(axis_angle_to_matrix(scaled_rel_axis_angle), r1)
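
A quick, hedged sanity check of interp_rotation; random_rotation comes from pytorch3d.transforms and the tolerance is a float32 guess:

import torch
from pytorch3d.transforms import random_rotation

r1, r2 = random_rotation(), random_rotation()
assert torch.allclose(interp_rotation(r1, r2, 0.0), r1, atol=1e-4)  # endpoint r1
assert torch.allclose(interp_rotation(r1, r2, 1.0), r2, atol=1e-4)  # endpoint r2
halfway = interp_rotation(r1, r2, 0.5)  # rotation halfway along the geodesic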
Example No. 9
    def test_matrix_to_quaternion_corner_case(self):
        """Check no bad gradients from sqrt(0)."""
        matrix = torch.eye(3, requires_grad=True)
        target = torch.Tensor([0.984808, 0, 0.174, 0])

        optimizer = torch.optim.Adam([matrix], lr=0.05)
        optimizer.zero_grad()
        q = matrix_to_quaternion(matrix)
        loss = torch.sum((q - target) ** 2)
        loss.backward()
        optimizer.step()

        self.assertClose(matrix, matrix, msg="Result has non-finite values")
        delta = 1e-2
        self.assertLess(
            matrix.trace(),
            3.0 - delta,
            msg="Identity initialisation unchanged by a gradient step",
        )
Example No. 10
    def test_matrix_to_quaternion_by_pi(self):
        # We check that rotations by pi around each of the 26
        # nonzero vectors containing nothing but 0, 1 and -1
        # are mapped to the right quaternions.
        # This is representative across the directions.
        options = [0.0, -1.0, 1.0]
        axes = [
            torch.tensor(vec)
            for vec in itertools.islice(  # exclude [0, 0, 0]
                itertools.product(options, options, options), 1, None
            )
        ]

        axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1)
        # Rotation by pi around unit vector x is given by
        # the matrix 2 x x^T - Id.
        R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3)
        quats_hat = matrix_to_quaternion(R)
        R_hat = quaternion_to_matrix(quats_hat)
        self.assertClose(R, R_hat, atol=1e-3)
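
A worked instance of the identity in the comment above, for a rotation by pi about the x-axis; the expected quaternion is stated up to the usual sign ambiguity:

import torch
from pytorch3d.transforms import matrix_to_quaternion

axis = torch.tensor([1.0, 0.0, 0.0])
R = 2 * torch.outer(axis, axis) - torch.eye(3)  # rotation by pi about x: diag(1, -1, -1)
print(matrix_to_quaternion(R))                  # ~[0, 1, 0, 0] (sign may flip)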
Example No. 11
 def test_to_quat(self):
     """mtx -> quat -> mtx"""
     data = random_rotations(13, dtype=torch.float64)
     mdata = quaternion_to_matrix(matrix_to_quaternion(data))
     self.assertTrue(torch.allclose(data, mdata))
Example No. 12
 def test_from_quat(self):
     """quat -> mtx -> quat"""
     data = random_quaternions(13, dtype=torch.float64)
     mdata = matrix_to_quaternion(quaternion_to_matrix(data))
     self.assertTrue(torch.allclose(data, mdata))
Example No. 13
def pred_synth(segpose,
               params: Params,
               mesh_type: str = "dolphin",
               device: str = "cuda"):

    if not params.pred_dir or not os.path.exists(params.pred_dir):
        raise FileNotFoundError(
            "Prediction directory has not been set or does not exist; "
            "set it via CLI args or params"
        )
    pred_folders = [
        join(params.pred_dir, f) for f in os.listdir(params.pred_dir)
    ]
    count = 1
    for p in sorted(pred_folders):
        try:
            print(p)
            manager = RenderManager.from_path(p)
            manager.rectify_paths(base_folder=params.pred_dir)
        except FileNotFoundError:
            continue
        # Run Silhouette Prediction Network
        logging.info(f"Starting mask predictions")
        mask_priors = []
        R_pred, T_pred = [], []
        q_loss, t_loss = 0, 0
        # Collect Translation stats
        R_gt, T_gt = manager._trajectory
        poses_gt = EvMaskPoseDataset.preprocess_poses(manager._trajectory)
        std_T, mean_T = torch.std_mean(T_gt)
        for idx in range(len(manager)):
            try:
                ev_frame = manager.get_event_frame(idx)
            except Exception as e:
                print(e)
                break
            mask_pred, pose_pred = predict_segpose(segpose, ev_frame,
                                                   params.threshold_conf,
                                                   params.img_size)
            # mask_pred = smooth_predicted_mask(mask_pred)
            manager.add_pred(idx, mask_pred, "silhouette")
            mask_priors.append(torch.from_numpy(mask_pred))

            # Make qexp a torch function
            # q_pred = qexp(pose_pred[:, 3:])
            # q_targ = qexp(poses_gt[idx, 3:].unsqueeze(0))
            ####  SHOULD THIS BE NORMALIZED ??
            q_pred = pose_pred[:, 3:]
            q_targ = poses_gt[idx, 3:]

            q_pred_unit = q_pred / torch.norm(q_pred)
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("learnt: ", q_pred_unit, q_targ_unit)

            t_pred = pose_pred[:, :3] * std_T + mean_T
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            T_pred.append(t_pred)

            q_loss += quaternion_angular_error(q_pred_unit, q_targ_unit)
            t_loss += t_error(t_pred, t_targ)

            r_pred = rc.quaternion_to_matrix(q_pred).unsqueeze(0)
            R_pred.append(r_pred.squeeze(0))

        q_loss_mean = q_loss / (idx + 1)
        t_loss_mean = t_loss / (idx + 1)

        # Convert R, T to world-to-view transforms (PyTorch3D convention).

        R_pred_abs = torch.cat(R_pred)
        T_pred_abs = torch.cat(T_pred)
        # Take inverse of view-to-world (output of net) to obtain w2v
        wtv_trans = (get_world_to_view_transform(
            R=R_pred_abs, T=T_pred_abs).inverse().get_matrix())
        T_pred = wtv_trans[:, 3, :3]
        R_pred = wtv_trans[:, :3, :3]
        R_pred_test = look_at_rotation(T_pred_abs)
        T_pred_test = -torch.bmm(R_pred_test.transpose(1, 2),
                                 T_pred_abs[:, :, None])[:, :, 0]
        # Convert back to view-to-world to get absolute
        vtw_trans = (get_world_to_view_transform(
            R=R_pred_test, T=T_pred_test).inverse().get_matrix())
        T_pred_trans = vtw_trans[:, 3, :3]
        R_pred_trans = vtw_trans[:, :3, :3]

        # Calc pose error for this:
        q_loss_mean_test = 0
        t_loss_mean_test = 0
        for idx in range(len(R_pred_test)):
            q_pred_trans = rc.matrix_to_quaternion(R_pred_trans[idx]).squeeze()
            q_targ = poses_gt[idx, 3:]
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("look: ", q_test, q_targ)
            q_loss_mean_test += quaternion_angular_error(
                q_pred_trans, q_targ_unit)
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            t_loss_mean_test += t_error(T_pred_trans[idx], t_targ)
        q_loss_mean_test /= idx + 1
        t_loss_mean_test /= idx + 1

        logging.info(
            f"Mean Translation Error: {t_loss_mean}; Mean Rotation Error: {q_loss_mean}"
        )
        logging.info(
            f"Mean Translation Error (LookAt): {t_loss_mean_test}; "
            f"Mean Rotation Error (LookAt): {q_loss_mean_test}"
        )

        # Plot estimated cameras
        logging.info(f"Plotting pose map")
        idx = random.sample(range(len(R_gt)), k=2)
        pose_plot = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred[idx], T_pred[idx]), params.device)
        pose_plot_test = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred_test[idx], T_pred_test[idx]),
            params.device)
        manager.add_pose_plot(pose_plot, "rot+trans")
        manager.add_pose_plot(pose_plot_test, "trans")

        count += 1
        groundtruth_silhouettes = manager._images("silhouette") / 255.0
        print(groundtruth_silhouettes.shape)
        print(torch.stack((mask_priors)).shape)
        seg_iou = neg_iou_loss(groundtruth_silhouettes,
                               torch.stack((mask_priors)) / 255.0)
        print("Seg IoU", seg_iou)

        # RUN MESH DEFORMATION
        # Run it 3 times: w/ Rot+Trans - w/ Trans+LookAt - w/ GT Pose
        experiments = {
            "GT-Pose": [R_gt, T_gt],
            # "Rot+Trans": [R_pred, T_pred],
            # "Trans+LookAt": [R_pred_test, T_pred_test]
        }

        results = {}
        input_m = torch.stack((mask_priors))

        for i in range(len(experiments.keys())):

            logging.info(
                f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
            # The MeshDeformation model will return silhouettes across all views by default

            mesh_model = MeshDeformationModel(device=device, params=params)

            R, T = list(experiments.values())[i]
            experiment_results = mesh_model.run_optimization(input_m, R, T)
            renders = mesh_model.render_final_mesh((R, T), "predict",
                                                   input_m.shape[-2:])

            mesh_silhouettes = renders["silhouettes"].squeeze(1)
            mesh_images = renders["images"].squeeze(1)
            experiment_name = list(experiments.keys())[i]
            for idx in range(len(mesh_silhouettes)):
                manager.add_pred(
                    idx,
                    mesh_silhouettes[idx].cpu().numpy(),
                    "silhouette",
                    destination=f"mesh_{experiment_name}",
                )
                manager.add_pred(
                    idx,
                    mesh_images[idx].cpu().numpy(),
                    "phong",
                    destination=f"mesh_{experiment_name}",
                )

            # Calculate chamfer loss:
            mesh_pred = mesh_model._final_mesh
            if mesh_type == "dolphin":
                path = "data/meshes/dolphin/dolphin.obj"
                mesh_gt = load_objs_as_meshes(
                    [path],
                    create_texture_atlas=False,
                    load_textures=True,
                    device=device,
                )
            # Shapenet Cars
            elif mesh_type == "shapenet":
                mesh_info = manager.metadata["mesh_info"]
                path = os.path.join(
                    params.gt_mesh_path,
                    f"ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj",
                )
                # path = f"data/ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                try:
                    verts, faces, aux = load_obj(path,
                                                 load_textures=True,
                                                 create_texture_atlas=True)

                    mesh_gt = Meshes(
                        verts=[verts],
                        faces=[faces.verts_idx],
                        textures=TexturesAtlas(atlas=[aux.texture_atlas]),
                    ).to(device)
                except Exception:
                    mesh_gt = None
                    print("CANNOT COMPUTE CHAMFER LOSS")
            if mesh_gt:
                mesh_pred_compute, mesh_gt_compute = scale_meshes(
                    mesh_pred.clone(), mesh_gt.clone())
                pcl_pred = sample_points_from_meshes(mesh_pred_compute,
                                                     num_samples=5000)
                pcl_gt = sample_points_from_meshes(mesh_gt_compute,
                                                   num_samples=5000)
                chamfer_loss = chamfer_distance(pcl_pred,
                                                pcl_gt,
                                                point_reduction="mean")
                print("CHAMFER LOSS: ", chamfer_loss)
                experiment_results["chamfer_loss"] = (
                    chamfer_loss[0].cpu().detach().numpy().tolist())

            mesh_iou = neg_iou_loss(groundtruth_silhouettes, mesh_silhouettes)

            experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()

            results[experiment_name] = experiment_results

            manager.add_pred_mesh(mesh_pred, experiment_name)
        # logging.info(f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
        # # The MeshDeformation model will return silhouettes across all view by default
        #
        #
        #
        # experiment_results = models["mesh"].run_optimization(input_m, R_gt, T_gt)
        # renders = models["mesh"].render_final_mesh(
        #     (R_gt, T_gt), "predict", input_m.shape[-2:]
        # )
        #
        # mesh_silhouettes = renders["silhouettes"].squeeze(1)
        # mesh_images = renders["images"].squeeze(1)
        # experiment_name = params.name
        # for idx in range(len(mesh_silhouettes)):
        #     manager.add_pred(
        #         idx,
        #         mesh_silhouettes[idx].cpu().numpy(),
        #         "silhouette",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #     manager.add_pred(
        #         idx,
        #         mesh_images[idx].cpu().numpy(),
        #         "phong",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #
        # # Calculate chamfer loss:
        # mesh_pred = models["mesh"]._final_mesh
        # if mesh_type == "dolphin":
        #     path = params.gt_mesh_path
        #     mesh_gt = load_objs_as_meshes(
        #         [path],
        #         create_texture_atlas=False,
        #         load_textures=True,
        #         device=device,
        #     )
        # # Shapenet Cars
        # elif mesh_type == "shapenet":
        #     mesh_info = manager.metadata["mesh_info"]
        #     path = params.gt_mesh_path
        #     try:
        #         verts, faces, aux = load_obj(
        #             path, load_textures=True, create_texture_atlas=True
        #         )
        #
        #         mesh_gt = Meshes(
        #             verts=[verts],
        #             faces=[faces.verts_idx],
        #             textures=TexturesAtlas(atlas=[aux.texture_atlas]),
        #         ).to(device)
        #     except:
        #         mesh_gt = None
        #         print("CANNOT COMPUTE CHAMFER LOSS")
        # if mesh_gt and params.is_real_data:
        #     mesh_pred_compute, mesh_gt_compute = scale_meshes(
        #         mesh_pred.clone(), mesh_gt.clone()
        #     )
        #     pcl_pred = sample_points_from_meshes(
        #         mesh_pred_compute, num_samples=5000
        #     )
        #     pcl_gt = sample_points_from_meshes(mesh_gt_compute, num_samples=5000)
        #     chamfer_loss = chamfer_distance(
        #         pcl_pred, pcl_gt, point_reduction="mean"
        #     )
        #     print("CHAMFER LOSS: ", chamfer_loss)
        #     experiment_results["chamfer_loss"] = (
        #         chamfer_loss[0].cpu().detach().numpy().tolist()
        #     )
        #
        # mesh_iou = neg_iou_loss_all(groundtruth_silhouettes, mesh_silhouettes)
        #
        # experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()
        #
        # results[experiment_name] = experiment_results
        #
        # manager.add_pred_mesh(mesh_pred, experiment_name)

        seg_iou = neg_iou_loss_all(groundtruth_silhouettes, input_m / 255.0)
        gt_iou = neg_iou_loss_all(groundtruth_silhouettes,
                                  groundtruth_silhouettes)

        results["mesh_iou"] = mesh_iou.detach().cpu().numpy().tolist()
        results["seg_iou"] = seg_iou.detach().cpu().numpy().tolist()
        logging.info(f"Mesh IOU list & results: {mesh_iou}")
        logging.info(f"Seg IOU list & results: {seg_iou}")
        logging.info(f"GT IOU list & results: {gt_iou} ")

        # results["mean_iou"] = IOULoss().forward(groundtruth, mesh_silhouettes).detach().cpu().numpy().tolist()
        # results["mean_dice"] = DiceCoeffLoss().forward(groundtruth, mesh_silhouettes)

        manager.set_pred_results(results)
        manager.close()
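
The "Trans+LookAt" camera construction used above (look_at_rotation followed by T = -R^T C) can be isolated into a short sketch; the camera centers here are synthetic stand-ins for T_pred_abs:

import torch
from pytorch3d.renderer import look_at_rotation

C = torch.randn(4, 3) + torch.tensor([0.0, 0.0, 3.0])  # camera centers
R_look = look_at_rotation(C)                            # (4, 3, 3), looking at the origin
T_look = -torch.bmm(R_look.transpose(1, 2), C[:, :, None])[:, :, 0]  # (4, 3)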
Example No. 14
def compute_rotation_quanternion(points):
    """Quaternion of the principal-axis (PCA) frame of an (N, 3) point array."""
    pca = PCA(n_components=3)
    pca.fit(points)
    return matrix_to_quaternion(
        torch.from_numpy(pca.components_.astype(np.float32)))
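
A hedged usage sketch for the PCA-based helper above (which assumes sklearn's PCA, torch and matrix_to_quaternion are imported at module level). Note that pca.components_ is orthogonal but may be a reflection (determinant -1), in which case the result is not a valid rotation quaternion and one axis would need its sign flipped first:

import numpy as np

points = np.random.randn(500, 3) * [3.0, 1.0, 0.2]  # elongated synthetic point cloud
q = compute_rotation_quanternion(points)
print(q)  # quaternion of the principal-axis frame (meaningful only if det(components_) == +1)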