Example #1
    def test_bad_so3_input_value_err(self):
        """
        Tests whether `so3_exponential_map` and `so3_log_map` correctly raise
        a ValueError when called with an argument of incorrect shape or, in the
        case of `so3_log_map`, with a rotation matrix whose trace lies outside
        the valid range.
        """
        device = torch.device("cuda:0")
        log_rot = torch.randn(size=[5, 4], device=device)
        with self.assertRaises(ValueError) as err:
            so3_exponential_map(log_rot)
        self.assertTrue(
            "Input tensor shape has to be Nx3." in str(err.exception))

        rot = torch.randn(size=[5, 3, 5], device=device)
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertTrue(
            "Input has to be a batch of 3x3 Tensors." in str(err.exception))

        # the trace of rot is guaranteed to be greater than 3 or smaller than -1
        rot = torch.cat((
            torch.rand(size=[5, 3, 3], device=device) + 4.0,
            torch.rand(size=[5, 3, 3], device=device) - 3.0,
        ))
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertTrue(
            "A matrix has trace outside valid range [-1-eps,3+eps]." in str(
                err.exception))
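The expected input shapes can be summarized with a minimal sketch (not part
of the test above; `so3_exponential_map` and `so3_log_map` are assumed to be
imported from pytorch3d.transforms):

    log_rot = torch.randn(5, 3)         # Nx3 axis-angle vectors
    R = so3_exponential_map(log_rot)    # -> (5, 3, 3) rotation matrices
    log_rot_back = so3_log_map(R)       # Nx3x3 rotations -> Nx3 log-rotations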
Example #2
 def init_random_cameras(cam_type: CamerasBase, batch_size: int):
     T = torch.randn(batch_size, 3) * 0.03
     T[:, 2] = 4
     R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
     cam_params = {'R': R, 'T': T}
     if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
         cam_params['znear'] = torch.rand(batch_size) * 10 + 0.1
         cam_params['zfar'] = (
             torch.rand(batch_size) * 4 + 1 + cam_params['znear']
         )
         if cam_type == OpenGLPerspectiveCameras:
             cam_params['fov'] = torch.rand(batch_size) * 60 + 30
             cam_params['aspect_ratio'] = torch.rand(batch_size) * 0.5 + 0.5
         else:
             cam_params['top'] = torch.rand(batch_size) * 0.2 + 0.9
             cam_params['bottom'] = -(torch.rand(batch_size)) * 0.2 - 0.9
             cam_params['left'] = -(torch.rand(batch_size)) * 0.2 - 0.9
             cam_params['right'] = torch.rand(batch_size) * 0.2 + 0.9
     elif cam_type in (SfMOrthographicCameras, SfMPerspectiveCameras):
         cam_params['focal_length'] = torch.rand(batch_size) * 10 + 0.1
         cam_params['principal_point'] = torch.randn((batch_size, 2))
     else:
         raise ValueError(str(cam_type))
     return cam_type(**cam_params)
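A hypothetical invocation of the helper above (assuming the camera classes
are imported from pytorch3d.renderer):

    cams = init_random_cameras(OpenGLPerspectiveCameras, batch_size=4)
    print(cams.R.shape, cams.T.shape)  # torch.Size([4, 3, 3]) torch.Size([4, 3])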
Example #3
    def calc(self, log_R, T):
        # Render the image using the updated camera position. The rotation
        # matrix is recovered from the current log-rotation vector.
        R = so3_exponential_map(log_R)
        image = self.renderer.render(self.meshes, R, T)

        # compare the rendered silhouette (alpha channel) with the reference image
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image
Example #4
 def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
     """
     Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
     a randomly generated batch of rotation matrix logarithms `log_rot`.
     """
     log_rot = TestSO3.init_log_rot(batch_size=batch_size)
     log_rot_ = so3_log_map(so3_exponential_map(log_rot))
     max_df = (log_rot - log_rot_).abs().max()
     self.assertAlmostEqual(float(max_df), 0.0, 4)
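For reference, a minimal sketch of the Rodrigues formula that an SO(3)
exponential map of this kind computes (a hand-rolled illustration under
standard conventions, not the library implementation):

    import torch

    def rodrigues(log_rot: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        """Axis-angle vectors (N, 3) -> rotation matrices (N, 3, 3)."""
        theta = log_rot.norm(dim=-1).clamp(min=eps)[:, None, None]  # (N, 1, 1)
        x, y, z = log_rot.unbind(-1)
        zeros = torch.zeros_like(x)
        # skew-symmetric (hat) matrices of the axis-angle vectors
        K = torch.stack(
            (zeros, -z, y, z, zeros, -x, -y, x, zeros), dim=-1
        ).view(-1, 3, 3)
        eye = torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)
        return (
            eye
            + torch.sin(theta) / theta * K
            + (1.0 - torch.cos(theta)) / theta ** 2 * torch.bmm(K, K)
        )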
Example #5
 def test_determinant(self):
     """
     Tests whether the determinants of 3x3 rotation matrices produced
     by `so3_exponential_map` are (almost) equal to 1.
     """
     log_rot = TestSO3.init_log_rot(batch_size=30)
     Rs = so3_exponential_map(log_rot)
     dets = torch.det(Rs)
     self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
Example #6
    def test_inverse(self, batch_size=5):
        device = torch.device("cuda:0")

        # generate a random chain of transforms
        for _ in range(10):  # 10 different tries

            # list of transform matrices
            ts = []

            for i in range(10):
                choice = float(torch.rand(1))
                if choice <= 1.0 / 3.0:
                    t_ = Translate(
                        torch.randn((batch_size, 3),
                                    dtype=torch.float32,
                                    device=device),
                        device=device,
                    )
                elif choice <= 2.0 / 3.0:
                    t_ = Rotate(
                        so3_exponential_map(
                            torch.randn(
                                (batch_size, 3),
                                dtype=torch.float32,
                                device=device,
                            )),
                        device=device,
                    )
                else:
                    rand_t = torch.randn((batch_size, 3),
                                         dtype=torch.float32,
                                         device=device)
                    rand_t = rand_t.sign() * torch.clamp(rand_t.abs(), 0.2)
                    t_ = Scale(rand_t, device=device)
                ts.append(t_._matrix.clone())

                if i == 0:
                    t = t_
                else:
                    t = t.compose(t_)

            # generate the inverse transformation in several possible ways
            m1 = t.inverse(invert_composed=True).get_matrix()
            m2 = t.inverse(invert_composed=True)._matrix
            m3 = t.inverse(invert_composed=False).get_matrix()
            m4 = t.get_matrix().inverse()

            # compute the inverse explicitly: each new inverse is left-multiplied,
            # so the factors accumulate in reverse order, matching
            # (M1 @ M2 @ ... @ Mn)^-1 = Mn^-1 @ ... @ M1^-1
            m5 = torch.eye(4, dtype=torch.float32, device=device)
            m5 = m5[None].repeat(batch_size, 1, 1)
            for t_ in ts:
                m5 = torch.bmm(torch.inverse(t_), m5)

            # assert all same
            for m in (m1, m2, m3, m4):
                self.assertTrue(torch.allclose(m, m5, atol=1e-3))
Example #7
 def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that `so3_exponential_map(so3_log_map(R))==R` for
     a batch of randomly generated rotation matrices `R`.
     """
     rot = TestSO3.init_rot(batch_size=batch_size)
     rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
     angles = so3_relative_angle(rot, rot_)
     # TODO: a lot of precision lost here ...
     self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
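The relative angle used above follows from the trace identity
trace(R) = 1 + 2*cos(theta). A minimal sketch of the computation that
`so3_relative_angle` is based on (the library version adds extra numerical
safeguards):

    def relative_angle(R1: torch.Tensor, R2: torch.Tensor) -> torch.Tensor:
        R12 = torch.bmm(R1.transpose(1, 2), R2)       # relative rotation
        trace = R12.diagonal(dim1=1, dim2=2).sum(-1)  # (N,)
        # clamp guards against round-off pushing the cosine outside [-1, 1]
        return torch.acos(((trace - 1.0) / 2.0).clamp(-1.0, 1.0))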
Example #8
 def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
     """
     Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
     a randomly generated batch of rotation matrix logarithms `log_rot`.
     """
     log_rot = TestSO3.init_log_rot(batch_size=batch_size)
      # also check the singular case where the rotation angle is 0
     log_rot[:1] = 0
     log_rot_ = so3_log_map(so3_exponential_map(log_rot))
     self.assertClose(log_rot, log_rot_, atol=1e-4)
Example #9
 def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that
     `so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
     == so3_exponential_map(log_rot)`
     for a randomly generated batch of rotation matrix logarithms `log_rot`.
     Unlike `test_so3_log_to_exp_to_log`, this test checks the
     correctness of converting a `log_rot` which contains values > math.pi.
     """
     log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
      # also check the singular cases where rot. angle = {0, pi, 2pi, 3pi}
      log_rot[:4] = 0
     log_rot[1, 0] = math.pi
     log_rot[2, 0] = 2.0 * math.pi
     log_rot[3, 0] = 3.0 * math.pi
     rot = so3_exponential_map(log_rot, eps=1e-8)
     rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
     angles = so3_relative_angle(rot, rot_)
     self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
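A sketch of why the comparison is made on rotations rather than on the
log-rotations themselves: an axis-angle vector with angle > pi encodes the
same rotation as a shorter equivalent vector, so the log map returns the
minimal representative instead of the original input:

    v = torch.tensor([[1.5 * math.pi, 0.0, 0.0]])
    v_back = so3_log_map(so3_exponential_map(v))
    # v_back is approximately [-0.5 * pi, 0, 0]: the same rotation, smaller angle
    assert torch.allclose(
        so3_exponential_map(v), so3_exponential_map(v_back), atol=1e-4)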
Example #10
 def test_determinant(self):
     """
     Tests whether the determinants of 3x3 rotation matrices produced
     by `so3_exponential_map` are (almost) equal to 1.
     """
     log_rot = TestSO3.init_log_rot(batch_size=30)
     Rs = so3_exponential_map(log_rot)
     for R in Rs:
         det = np.linalg.det(R.cpu().numpy())
         self.assertAlmostEqual(float(det), 1.0, 5)
Example #11
def init_uniform_y_rotations(batch_size: int = 10):
    """
    Generate a batch of `batch_size` 3x3 rotation matrices about the y-axis
    with angles uniformly spaced in the interval [0, 2 pi).
    """
    device = torch.device("cuda:0")
    axis = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)
    angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device)
    angles = angles[:batch_size]
    log_rots = axis[None, :] * angles[:, None]
    R = so3_exponential_map(log_rots)
    return R
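A quick sanity check for the helper above (a sketch, not part of the
original code): a rotation about the y-axis leaves the y-axis itself fixed.

    R = init_uniform_y_rotations(batch_size=4)
    y = torch.tensor([0.0, 1.0, 0.0], device=R.device).expand(4, 3)
    assert torch.allclose(torch.bmm(y[:, None], R)[:, 0], y, atol=1e-5)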
Example #12
 def test_inverse(self, batch_size=5):
     device = torch.device("cuda:0")
     log_rot = torch.randn((batch_size, 3),
                           dtype=torch.float32,
                           device=device)
     R = so3_exponential_map(log_rot)
     t = Rotate(R)
     im = t.inverse()._matrix
     im_2 = t._matrix.inverse()
     im_comp = t.get_matrix().inverse()
     self.assertTrue(torch.allclose(im, im_comp, atol=1e-4))
     self.assertTrue(torch.allclose(im, im_2, atol=1e-4))
Example #13
 def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
     """
     Check that `so3_exponential_map(so3_log_map(R))==R` for
     a batch of randomly generated rotation matrices `R`.
     """
     rot = TestSO3.init_rot(batch_size=batch_size)
     rot_ = so3_exponential_map(so3_log_map(rot))
     angles = so3_relative_angle(rot, rot_)
     max_angle = angles.max()
     # a lot of precision lost here :(
     # TODO: fix this test??
     self.assertTrue(np.allclose(float(max_angle), 0.0, atol=0.1))
Example #14
 def test_rotate(self):
     R = so3_exponential_map(torch.randn((1, 3)))
     t = Transform3d().rotate(R)
     points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                            [0.5, 0.5, 0.0]]).view(1, 3, 3)
     normals = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                             [1.0, 1.0, 0.0]]).view(1, 3, 3)
     points_out = t.transform_points(points)
     normals_out = t.transform_normals(normals)
     points_out_expected = torch.bmm(points, R)
     normals_out_expected = torch.bmm(normals, R)
     self.assertTrue(torch.allclose(points_out, points_out_expected))
     self.assertTrue(torch.allclose(normals_out, normals_out_expected))
Example #15
 def test_so3_exp_singularity(self, batch_size: int = 100):
     """
      Tests whether `so3_exponential_map` is robust to input vectors whose
      norms lie close to the numerically unstable region (vectors with very
      small l2-norms).
     """
     # generate random log-rotations with a tiny angle
     log_rot = TestSO3.init_log_rot(batch_size=batch_size)
     log_rot_small = log_rot * 1e-6
     R = so3_exponential_map(log_rot_small)
     # tests whether all outputs are finite
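      # (NaN != NaN, so the equality below fails iff R contains NaN values)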
     R_sum = float(R.sum())
     self.assertEqual(R_sum, R_sum)
Example #16
 def test_get_camera_center(self, batch_size=10):
     T = torch.randn(batch_size, 3)
     R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
     for cam_type in (
             OpenGLPerspectiveCameras,
             OpenGLOrthographicCameras,
             SfMOrthographicCameras,
             SfMPerspectiveCameras,
     ):
         cam = cam_type(R=R, T=T)
         C = cam.get_camera_center()
         C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
         self.assertTrue(torch.allclose(C, C_, atol=1e-05))
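The reference value relies on the row-vector convention used by PyTorch3D
(x_view = x_world @ R + T), under which the camera center C satisfies
C @ R + T = 0. A sanity-check sketch under that assumption:

    cam = SfMPerspectiveCameras(R=R, T=T)
    C = cam.get_camera_center()  # (N, 3)
    zeros = cam.get_world_to_view_transform().transform_points(C[:, None])
    assert torch.allclose(zeros, torch.zeros_like(zeros), atol=1e-4)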
Example #17
    def test_blender_camera(self):
        """
        Test BlenderCamera.
        """
        # Test get_world_to_view_transform.
        T = torch.randn(10, 3)
        R = so3_exponential_map(torch.randn(10, 3) * 3.0)
        RT = get_world_to_view_transform(R=R, T=T)
        cam = BlenderCamera(R=R, T=T)
        RT_class = cam.get_world_to_view_transform()
        self.assertTrue(torch.allclose(RT.get_matrix(), RT_class.get_matrix()))
        self.assertTrue(isinstance(RT, Transform3d))

        # Test getting camera center.
        C = cam.get_camera_center()
        C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
        self.assertTrue(torch.allclose(C, C_, atol=1e-05))
Example #18
def init_random_cameras(cam_type: typing.Type[CamerasBase],
                        batch_size: int,
                        random_z: bool = False):
    T = torch.randn(batch_size, 3) * 0.03
    if not random_z:
        T[:, 2] = 4
    R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
    cam_params = {"R": R, "T": T}
    if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
        if cam_type == OpenGLPerspectiveCameras:
            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
            cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
        if cam_type == FoVPerspectiveCameras:
            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
            cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in (
            SfMOrthographicCameras,
            SfMPerspectiveCameras,
            OrthographicCameras,
            PerspectiveCameras,
    ):
        cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
        cam_params["principal_point"] = torch.randn((batch_size, 2))

    else:
        raise ValueError(str(cam_type))
    return cam_type(**cam_params)
Example #19
    def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int):
        T = torch.randn(batch_size, 3) * 0.03
        T[:, 2] = 4
        R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
        screen_cam_params = {"R": R, "T": T}
        ndc_cam_params = {"R": R, "T": T}
        if cam_type in (OrthographicCameras, PerspectiveCameras):
            ndc_cam_params["focal_length"] = torch.rand((batch_size, 2)) * 3.0
            ndc_cam_params["principal_point"] = torch.randn((batch_size, 2))

            image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
            screen_cam_params["image_size"] = image_size
            screen_cam_params["focal_length"] = (
                ndc_cam_params["focal_length"] * image_size / 2.0)
            screen_cam_params["principal_point"] = (
                (1.0 - ndc_cam_params["principal_point"]) * image_size / 2.0)
        else:
            raise ValueError(str(cam_type))
        return cam_type(**ndc_cam_params), cam_type(**screen_cam_params)
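A worked instance of the NDC-to-screen mapping used above (illustrative
numbers only): for a 64x64 image, an NDC focal length of 1.0 corresponds to
32 pixels, and an NDC principal point of 0.0 lands at the pixel center.

    image_size = torch.tensor([[64, 64]])
    focal_ndc = torch.tensor([[1.0, 1.0]])
    pp_ndc = torch.tensor([[0.0, 0.0]])
    focal_screen = focal_ndc * image_size / 2.0    # tensor([[32., 32.]])
    pp_screen = (1.0 - pp_ndc) * image_size / 2.0  # tensor([[32., 32.]])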
Example #20
def camera_calibration(flamelayer, target_silhouette, log_R, T, optimizer, renderer):
    '''
    Fit the FLAME head model to a target silhouette by optimizing the rigid
    camera transformation.
    :param flamelayer:        FLAME parametric model
    :param target_silhouette: target silhouette image
    :param log_R:             camera log-rotation (axis-angle), optimized in place
    :param T:                 camera translation, optimized in place
    :param optimizer:         optimizer over (log_R, T)
    :param renderer:          differentiable silhouette renderer
    :return: the camera rotation matrix `so3_exponential_map(log_R)` and translation `T`
    '''
    # torch_target_2d_lmks = torch.from_numpy(target_2d_lmks).cuda()
    # factor = max(max(target_2d_lmks[:,0]) - min(target_2d_lmks[:,0]),max(target_2d_lmks[:,1]) - min(target_2d_lmks[:,1]))

    # def image_fit_loss(landmarks_3D):
    #     landmarks_2D = torch_project_points_weak_perspective(landmarks_3D, scale)
    #     return flamelayer.weights['lmk']*torch.sum(torch.sub(landmarks_2D,torch_target_2d_lmks)**2) / (factor ** 2)

    # Set the cuda device
    device = torch.device("cuda:0")

    verts, _, _ = flamelayer()
    verts = verts.detach()
    # Initialize each vertex to be white in color.
    verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
    textures = Textures(verts_rgb=verts_rgb.to(device))
    faces = torch.tensor(np.int32(flamelayer.faces), dtype=torch.long).cuda()

    my_mesh = Meshes(
        verts=[verts.to(device)],
        faces=[faces.to(device)],
        textures=textures
    )

    silhouette_err = SilhouetteErr(my_mesh, renderer, target_silhouette)

    def fit_closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss, sil = silhouette_err.calc(log_R, T)

        obj = loss
        # print(loss)
        # print('cam pos', cam_pos)
        if obj.requires_grad:
            obj.backward()
        return obj

    def log_obj(msg):
        if FIT_2D_DEBUG_MODE:
            vertices, landmarks_3D, flame_regularizer_loss = flamelayer()
            # print(msg + ' obj = ', image_fit_loss(landmarks_3D))

    def log(msg):
        if FIT_2D_DEBUG_MODE:
            print(msg)

    # visualize the rendered vs. reference silhouettes before optimization
    loss, sil = silhouette_err.calc(log_R, T)
    sil1 = sil[..., 3].detach().squeeze().cpu().numpy()
    sil2 = silhouette_err.image_ref.detach().cpu().numpy()
    plt.subplot(121)
    plt.imshow(sil1)
    plt.subplot(122)
    plt.imshow(sil2)
    plt.show()
    log('Optimizing rigid transformation')
    log_obj('Before optimization obj')
    optimizer.step(fit_closure)
    log_obj('After optimization obj')
    # visualize the rendered vs. reference silhouettes after optimization
    loss, sil = silhouette_err.calc(log_R, T)
    sil1 = sil[..., 3].detach().squeeze().cpu().numpy()
    sil2 = silhouette_err.image_ref.detach().cpu().numpy()
    plt.subplot(121)
    plt.imshow(sil1)
    plt.subplot(122)
    plt.imshow(sil2)
    plt.show()
    return so3_exponential_map(log_R), T
Example #21
 def compute_rots():
     so3_exponential_map(log_rot)
     torch.cuda.synchronize()
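The closure above is presumably meant for a GPU benchmark: the
`torch.cuda.synchronize()` call makes the timing include kernel execution.
A hypothetical harness (assuming `log_rot` already lives on the GPU):

    import torch
    import torch.utils.benchmark as benchmark
    from pytorch3d.transforms import so3_exponential_map

    log_rot = torch.randn(10000, 3, device="cuda")

    def compute_rots():
        so3_exponential_map(log_rot)
        torch.cuda.synchronize()

    timer = benchmark.Timer(stmt="compute_rots()",
                            globals={"compute_rots": compute_rots})
    print(timer.timeit(100))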
Example #22
def main(args):
    # set for reproducibility
    torch.manual_seed(42)
    if args.dtype == "float":
        args.dtype = torch.float32
    elif args.dtype == "double":
        args.dtype = torch.float64

    # ## 1. Set up Cameras and load ground truth positions

    # load the SE3 graph of relative/absolute camera positions
    if (args.input_folder / "images.bin").isfile():
        ext = '.bin'
    elif (args.input_folder / "images.txt").isfile():
        ext = '.txt'
    else:
        print("error: neither images.bin nor images.txt found in the input folder")
        return
    cameras, images, points3D = read_model(args.input_folder, ext)

    images_df = pd.DataFrame.from_dict(images, orient="index").set_index("id")
    cameras_df = pd.DataFrame.from_dict(cameras,
                                        orient="index").set_index("id")
    points_df = pd.DataFrame.from_dict(points3D,
                                       orient="index").set_index("id")
    print(points_df)
    print(images_df)
    print(cameras_df)

    ref_pointcloud = PyntCloud.from_file(args.ply)
    ref_pointcloud = torch.from_numpy(ref_pointcloud.xyz).to(device,
                                                             dtype=args.dtype)

    points_3d = np.stack(points_df["xyz"].values)
    points_3d = torch.from_numpy(points_3d).to(device, dtype=args.dtype)

    cameras_R = np.stack(
        [qvec2rotmat(q) for _, q in images_df["qvec"].items()])
    cameras_R = torch.from_numpy(cameras_R).to(
        device, dtype=args.dtype).transpose(1, 2)

    cameras_T = torch.from_numpy(np.stack(images_df["tvec"].values)).to(
        device, dtype=args.dtype)

    cameras_params = torch.from_numpy(np.stack(
        cameras_df["params"].values)).to(device, dtype=args.dtype)
    cameras_params = cameras_params[:, :4]
    print(cameras_params)

    # Construct a visibility map: True at (frame, point) if the point is visible
    # in that frame, False otherwise.
    # Thus, we can ignore reprojection errors for invisible points
    visibility = np.full((cameras_R.shape[0], points_3d.shape[0]), False)
    visibility = pd.DataFrame(visibility,
                              index=images_df.index,
                              columns=points_df.index)

    points_2D_gt = []
    for idx, (pts_ids, xy) in images_df[["point3D_ids", "xys"]].iterrows():
        pts_ids_clean = pts_ids[pts_ids != -1]
        pts_2D = pd.DataFrame(xy[pts_ids != -1], index=pts_ids_clean)
        pts_2D = pts_2D[~pts_2D.index.duplicated(keep=False)].reindex(
            points_df.index).dropna()
        points_2D_gt.append(pts_2D.values)
        visibility.loc[idx, pts_2D.index] = True

    print(visibility)

    visibility = torch.from_numpy(visibility.values).to(device)
    eps = 1e-3
    # The visibility map is very sparse, so we can use PyTorch3D's `list_to_padded`
    # to reduce points_2D to shape (num_frames, max points seen by a frame)
    points_2D_gt = list_to_padded([torch.from_numpy(p) for p in points_2D_gt],
                                  pad_value=eps).to(device, dtype=args.dtype)
    print(points_2D_gt)

    cameras_df["raw_id"] = np.arange(len(cameras_df))
    cameras_id_per_image = torch.from_numpy(
        cameras_df["raw_id"][images_df["camera_id"]].values).to(device)
    # the number of absolute camera positions
    N = len(images_df)
    nonzer = (points_2D_gt != eps).all(dim=-1)

    # print(padded)
    # print(points_2D_gt, points_2D_gt.shape)

    # ## 2. Define optimization functions
    #
    # ### Relative cameras and camera distance
    # We now define a function crucial for the optimization.
    #
    # **`calc_camera_distance`** compares a pair of cameras.
    # This function is important as it defines the loss that we are minimizing.
    # The method utilizes the `so3_relative_angle` function from the SO3 API.
    #
    # (A companion helper, **`get_relative_camera`**, which computes the
    # parameters of a relative camera mapping between a pair of absolute
    # cameras via the `compose` and `inverse` class methods of the PyTorch3D
    # Transforms API, is not used in this script.)

    def calc_camera_distance(cam_1, cam_2):
        """
        Calculates the divergence of a batch of pairs of cameras cam_1, cam_2.
        The distance is composed of the cosine of the relative angle between
        the rotation components of the camera extrinsics and the l2 distance
        between the translation vectors.
        """
        # rotation distance
        R_distance = (
            1. - so3_relative_angle(cam_1.R, cam_2.R, cos_angle=True)).mean()
        # translation distance
        T_distance = ((cam_1.T - cam_2.T)**2).sum(1).mean()
        # the final distance is the sum
        return R_distance + T_distance

    # ## 3. Optimization
    # Finally, we start the optimization of the absolute cameras.
    #
    # We use SGD with momentum and optimize over `log_R_absolute` and `T_absolute`.
    #
    # As mentioned earlier, `log_R_absolute` is the axis angle representation of the
    # rotation part of our absolute cameras. We can obtain the 3x3 rotation matrix
    # `R_absolute` that corresponds to `log_R_absolute` with:
    #
    # `R_absolute = so3_exponential_map(log_R_absolute)`
    #

    fxfyu0v0 = cameras_params[cameras_id_per_image]
    cameras_absolute_gt = PerspectiveCameras(
        focal_length=fxfyu0v0[:, :2],
        principal_point=fxfyu0v0[:, 2:],
        R=cameras_R,
        T=cameras_T,
        device=device,
    )

    # Normally, the points_2d are the ones we should use to minimize reprojection
    # errors. But we have been dealing with instability, so we reproject the 3D
    # points instead and use their reprojections as targets, since we assume
    # COLMAP's bundle adjuster has already converged on its own.
    use_3d_points = True
    if use_3d_points:
        with torch.no_grad():
            padded_points = list_to_padded(
                [points_3d[visibility[c]] for c in range(N)], pad_value=1e-3)
            points_2D_gt = cameras_absolute_gt.transform_points(
                padded_points, eps=1e-4)[:, :, :2]
            relative_points_gt = padded_points @ cameras_R + cameras_T[:, None]

    # The starting point is normally points_3d, cameras_R and cameras_T.
    # For a stability test, you can try to add noise and see if the optimization
    # gets back to the initial state (spoiler alert: it's complicated).
    # Set noise and shift to 0 for a normal starting point
    noise = 0
    shift = 0.1
    points_init = points_3d + noise * torch.randn(
        points_3d.shape, dtype=torch.float32, device=device) + shift

    log_R_init = so3_log_map(cameras_R) + noise * torch.randn(
        N, 3, dtype=torch.float32, device=device)
    T_init = cameras_T + noise * torch.randn(
        cameras_T.shape, dtype=torch.float32, device=device) - shift
    cams_init = cameras_params  # + noise * torch.randn(cameras_params.shape, dtype=torch.float32, device=device)

    # instantiate a copy of the initialization of log_R / T
    log_R = log_R_init.clone().detach()
    log_R.requires_grad = True
    T = T_init.clone().detach()
    T.requires_grad = True

    cams_params = cams_init.clone().detach()
    cams_params.requires_grad = True

    points = points_init.clone().detach()
    points.requires_grad = True

    # init the optimizer
    # Different learning rates per parameter? Intuitively, the rate should be
    # higher for T and lower for log_R.
    # The intrinsics could be optimized as well, but it's unlikely to be interesting
    param_groups = [
        {'params': points, 'lr': args.lr},
        {'params': log_R, 'lr': 0.1 * args.lr},
        {'params': T, 'lr': 2 * args.lr},
        {'params': cams_params, 'lr': 0},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9)

    # run the optimization
    n_iter = 200000  # fix the number of iterations
    # Compute inliers
    # In the model, some 3d points have reprojections way off compared to the
    # target 2d point. This is potentially a great source of instability.
    # `inliers` keeps track of those problematic points to discard them from
    # the optimization.
    discard_outliers = True
    if discard_outliers:
        with torch.no_grad():
            padded_points = list_to_padded(
                [points_3d[visibility[c]] for c in range(N)], pad_value=1e-3)
            projected_points = cameras_absolute_gt.transform_points(
                padded_points, eps=1e-4)[:, :, :2]
            points_distance = ((projected_points[nonzer] -
                                points_2D_gt[nonzer])**2).sum(dim=1)
            inliers = (points_distance < 100).clone().detach()
            print(inliers)
    else:
        inliers = points_2D_gt[nonzer] == points_2D_gt[
            nonzer]  # All true, except NaNs
    loss_log = []
    cam_dist_log = []
    pts_dist_log = []
    for it in range(n_iter):
        # re-init the optimizer gradients
        optimizer.zero_grad()
        R = so3_exponential_map(log_R)

        fxfyu0v0 = cams_params[cameras_id_per_image]
        # get the current absolute cameras
        cameras_absolute = PerspectiveCameras(
            focal_length=fxfyu0v0[:, :2],
            principal_point=fxfyu0v0[:, 2:],
            R=R,
            T=T,
            device=device,
        )

        padded_points = list_to_padded(
            [points[visibility[c]] for c in range(N)], pad_value=1e-3)

        # two ways of optimizing:
        # 1) minimize the 2d projection error. Potentially unstable, especially
        # with very close points: close points are the ones for which we want the
        # pose modification to be small, yet gradient descent gives them the
        # largest step size. We could use Adam instead, but instability remains.
        #
        # 2) minimize the 3d relative position error (the initial 3d relative
        # position is taken as ground truth). No more instability for very close
        # points, but the 2d reprojection error is not guaranteed to be minimized.

        minimize_2d = True
        chamfer_weight = 1e3
        verbose = True

        chamfer_dist = chamfer_distance(ref_pointcloud[None], points[None])[0]
        if minimize_2d:
            projected_points_3D = cameras_absolute.transform_points(
                padded_points, eps=1e-4)
            projected_points = projected_points_3D[:, :, :2]
            # Discard points with a depth < 0 (theoretically impossible)
            inliers = inliers & (projected_points_3D[:, :, 2][nonzer] > 0)

            # Plot point distants for first image
            # distances = (projected_points[0] - points_2D_gt[0]).norm(dim=-1).detach().cpu().numpy()
            # from matplotlib import pyplot as plt
            # plt.plot(distances[:(visibility[0]).sum()])

            # Different loss functions for reprojection error minimization
            # points_distance = smooth_l1_loss(projected_points, points_2D_gt)
            # points_distance = (smooth_l1_loss(projected_points, points_2D_gt, reduction='none')[nonzer]).sum(dim=1)
            proj_error = ((projected_points[nonzer] -
                           points_2D_gt[nonzer])**2).sum(dim=1)
            proj_error_filtered = proj_error[inliers]
        else:
            projected_points_3D = padded_points @ R + T[:, None]

            # Plot point distants for first image
            # distances = (projected_points_3D[0] - relative_points_gt[0]).norm(dim=-1).detach().cpu().numpy()
            # from matplotlib import pyplot as plt
            # plt.plot(distances[:(visibility[0]).sum()])

            # Different loss functions for reprojection error minimization
            # points_distance = smooth_l1_loss(projected_points, points_2D_gt)
            # points_distance = (smooth_l1_loss(projected_points, points_2D_gt, reduction='none')[nonzer]).sum(dim=1)
            proj_error = ((projected_points_3D[nonzer] -
                           relative_points_gt[nonzer])**2).sum(dim=1)
            proj_error_filtered = proj_error[inliers]

        loss = proj_error_filtered.mean() + chamfer_weight * chamfer_dist
        loss.backward()

        if verbose:
            print("faulty elements (with nan grad) :")
            faulty_points = torch.arange(
                points.shape[0])[points.grad[:, 0] != points.grad[:, 0]]
            faulty_images = torch.arange(
                log_R.shape[0])[log_R.grad[:, 0] != log_R.grad[:, 0]]
            faulty_cams = torch.arange(cams_params.shape[0])[
                cams_params.grad[:, 0] != cams_params.grad[:, 0]]
            faulty_projected_points = torch.arange(
                projected_points.shape[1])[torch.isnan(
                    projected_points.grad).any(dim=2)[0]]

            # Print Tensor that would become NaN, should the gradient be applied
            print("Faulty Rotation (log) and translation")
            print(faulty_images)
            print(log_R[faulty_images])
            print(T[faulty_images])
            print("Faulty 3D colmap points")
            print(faulty_points)
            print(points[faulty_points])
            print("Faulty Cameras")
            print(faulty_cams)
            print(cams_params[faulty_cams])
            print("Faulty 2D points")
            print(projected_points[faulty_projected_points])
            first_faulty_point = points_df.iloc[int(faulty_points[0])]
            related_faulty_images = images_df.loc[
                first_faulty_point["image_ids"][0]]

            print("First faulty point, and images where it is seen")
            print(first_faulty_point)
            print(related_faulty_images)

        # apply the gradients
        optimizer.step()

        # plot and print status message
        if it % 2000 == 0 or it == n_iter - 1:
            camera_distance = calc_camera_distance(cameras_absolute,
                                                   cameras_absolute_gt)
            print(
                'iteration = {}; loss = {}, chamfer distance = {}, camera_distance = {}'
                .format(it, loss, chamfer_dist, camera_distance))
            loss_log.append(loss.item())
            pts_dist_log.append(chamfer_dist.item())
            cam_dist_log.append(camera_distance.item())
        if it % 20000 == 0 or it == n_iter - 1:
            with torch.no_grad():
                from matplotlib import pyplot as plt
                plt.hist(
                    torch.sqrt(proj_error_filtered).detach().cpu().numpy())
        if it % 200000 == 0 or it == n_iter - 1:
            plt.figure()
            plt.plot(loss_log)
            plt.figure()
            plt.plot(pts_dist_log, label="chamfer_dist")
            plt.plot(cam_dist_log, label="cam_dist")
            plt.legend()
            plot_camera_scene(
                cameras_absolute, cameras_absolute_gt, points, ref_pointcloud,
                'iteration={}; chamfer distance={}'.format(
                    it, chamfer_dist))

    print('Optimization finished.')
Example #23
    def _corresponding_cameras_alignment_test_case(
        self,
        cameras,
        R_align_gt,
        T_align_gt,
        s_align_gt,
        estimate_scale,
        mode,
        add_noise,
    ):
        batch_size = cameras.R.shape[0]

        # get target camera centers
        R_new = torch.bmm(R_align_gt[None].expand_as(cameras.R), cameras.R)
        T_new = (
            torch.bmm(T_align_gt[None, None].repeat(batch_size, 1, 1), cameras.R)[:, 0]
            + cameras.T
        ) * s_align_gt

        if add_noise != 0.0:
            R_new = torch.bmm(
                R_new, so3_exponential_map(torch.randn_like(T_new) * add_noise)
            )
            T_new += torch.randn_like(T_new) * add_noise

        # create new cameras from R_new and T_new
        cameras_tgt = cameras.clone()
        cameras_tgt.R = R_new
        cameras_tgt.T = T_new

        # align cameras and cameras_tgt
        cameras_aligned = corresponding_cameras_alignment(
            cameras, cameras_tgt, estimate_scale=estimate_scale, mode=mode
        )

        if batch_size <= 2 and mode == "centers":
            # underdetermined case - check only the center alignment error
            # since the rotation and translation are ambiguous here
            self.assertClose(
                cameras_aligned.get_camera_center(),
                cameras_tgt.get_camera_center(),
                atol=max(add_noise * 7.0, 1e-4),
            )

        else:

            def _rmse(a):
                return (torch.norm(a, dim=1, p=2) ** 2).mean().sqrt()

            if add_noise != 0.0:
                # in a noisy case check mean rotation/translation error for
                # extrinsic alignment and root mean center error for center alignment
                if mode == "centers":
                    self.assertNormsClose(
                        cameras_aligned.get_camera_center(),
                        cameras_tgt.get_camera_center(),
                        _rmse,
                        atol=max(add_noise * 10.0, 1e-4),
                    )
                elif mode == "extrinsics":
                    angle_err = so3_relative_angle(
                        cameras_aligned.R, cameras_tgt.R
                    ).mean()
                    self.assertClose(
                        angle_err, torch.zeros_like(angle_err), atol=add_noise * 10.0
                    )
                    self.assertNormsClose(
                        cameras_aligned.T, cameras_tgt.T, _rmse, atol=add_noise * 7.0
                    )
                else:
                    raise ValueError(mode)

            else:
                # compare the rotations and translations of cameras
                self.assertClose(cameras_aligned.R, cameras_tgt.R, atol=3e-4)
                self.assertClose(cameras_aligned.T, cameras_tgt.T, atol=3e-4)
                # compare the centers
                self.assertClose(
                    cameras_aligned.get_camera_center(),
                    cameras_tgt.get_camera_center(),
                    atol=3e-4,
                )