Example #1
 def verts_features_padded(self) -> torch.Tensor:
     if self._verts_features_padded is None:
         if self.isempty():
             self._verts_features_padded = torch.zeros((self._N, 0, 3, 0),
                                                       dtype=torch.float32,
                                                       device=self.device)
         else:
             self._verts_features_padded = list_to_padded(
                 self._verts_features_list, pad_value=0.0)
     return self._verts_features_padded
Example #2
    def test_list_to_padded(self):
        device = torch.device('cuda:0')
        N = 5
        K = 20
        ndim = 2
        x = []
        for _ in range(N):
            dims = torch.randint(K, size=(ndim, )).tolist()
            x.append(torch.rand(dims, device=device))
        pad_size = [K] * ndim
        x_padded = struct_utils.list_to_padded(x,
                                               pad_size=pad_size,
                                               pad_value=0.0,
                                               equisized=False)

        self.assertEqual(x_padded.shape[1], K)
        self.assertEqual(x_padded.shape[2], K)
        for i in range(N):
            self.assertClose(x_padded[i, :x[i].shape[0], :x[i].shape[1]], x[i])

        # check for no pad size (defaults to max dimension)
        x_padded = struct_utils.list_to_padded(x,
                                               pad_value=0.0,
                                               equisized=False)
        max_size0 = max(y.shape[0] for y in x)
        max_size1 = max(y.shape[1] for y in x)
        self.assertEqual(x_padded.shape[1], max_size0)
        self.assertEqual(x_padded.shape[2], max_size1)
        for i in range(N):
            self.assertClose(x_padded[i, :x[i].shape[0], :x[i].shape[1]], x[i])

        # check for equisized
        x = [torch.rand((K, 10), device=device) for _ in range(N)]
        x_padded = struct_utils.list_to_padded(x, equisized=True)
        self.assertClose(x_padded, torch.stack(x, 0))

        # catch ValueError for invalid dimensions
        with self.assertRaisesRegex(ValueError, 'Pad size must'):
            pad_size = [K] * 4
            struct_utils.list_to_padded(x,
                                        pad_size=pad_size,
                                        pad_value=0.0,
                                        equisized=False)

        # invalid input tensor dimensions
        x = []
        ndim = 3
        for _ in range(N):
            dims = torch.randint(K, size=(ndim, )).tolist()
            x.append(torch.rand(dims, device=device))
        pad_size = [K] * 2
        with self.assertRaisesRegex(ValueError, 'Supports only'):
            x_padded = struct_utils.list_to_padded(x,
                                                   pad_size=pad_size,
                                                   pad_value=0.0,
                                                   equisized=False)
Example #3
def _list_to_padded_wrapper(
    x: List[torch.Tensor],
    pad_size: Union[list, tuple, None] = None,
    pad_value: float = 0.0,
) -> torch.Tensor:
    r"""
    This is a wrapper around the pytorch3d.structures.utils.list_to_padded
    function, which only supports lists of 2-dimensional tensors; the wrapper
    flattens any trailing dimensions so that inputs with more dimensions can
    be padded as well.

    For this use case, the input x is a list of tensors of shape (F, 3, ...)
    where only F differs between the elements of the list.

    Transforms a list of N tensors each of shape (Mi, ...) into a single tensor
    of shape (N, pad_size, ...), or (N, max(Mi), ...)
    if pad_size is None.

    Args:
      x: list of Tensors
      pad_size: int specifying the size of the first dimension
        of the padded tensor
      pad_value: float value to be used to fill the padded tensor

    Returns:
      x_padded: tensor consisting of padded input tensors
    """
    N = len(x)
    # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
    dims = x[0].ndim
    reshape_dims = x[0].shape[1:]
    D = torch.prod(torch.tensor(reshape_dims)).item()
    x_reshaped = []
    for y in x:
        if y.ndim != dims or y.shape[1:] != reshape_dims:
            msg = (
                "list_to_padded requires tensors to have the same number of dimensions"
            )
            raise ValueError(msg)
        x_reshaped.append(y.reshape(-1, D))
    x_padded = list_to_padded(x_reshaped,
                              pad_size=pad_size,
                              pad_value=pad_value)
    return x_padded.reshape((N, -1) + reshape_dims)
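A minimal sketch of the padding behaviour described in the docstring, calling pytorch3d.structures.utils.list_to_padded directly on flattened tensors the same way the wrapper does internally (the tensor sizes here are illustrative):

import torch
from pytorch3d.structures.utils import list_to_padded

faces_verts = [torch.rand(4, 3, 3), torch.rand(2, 3, 3)]      # shapes (Fi, 3, 3)
# flatten trailing dims so each element is 2-dimensional: (Fi, 9)
flat = [y.reshape(-1, 9) for y in faces_verts]
padded = list_to_padded(flat, pad_value=0.0)                  # (2, max(Fi), 9)
padded = padded.reshape((2, -1) + faces_verts[0].shape[1:])   # (2, max(Fi), 3, 3)
assert padded.shape == (2, 4, 3, 3)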
Example #4
    def forward(self, mesh):
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()

        for gconv in self.gconvs:
            verts = F.relu(gconv(verts, edges))

        ### VERTS ###
        verts_idx = mesh.verts_packed_to_mesh_idx()
        verts_size = verts_idx.unique(return_counts=True)[1]
        verts_packed = packed_to_list(verts, tuple(verts_size))
        verts_padded = list_to_padded(verts_packed)

        #         out  = torch.sum(verts_padded, 1)/verts_size.view(-1,1)
        out = torch.max(verts_padded, 1, keepdim=True)[0]
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        out = torch.sigmoid(out)
        return out
Example #5
    def test_list_to_padded(self):
        device = torch.device("cuda:0")
        N = 5
        K = 20
        for ndim in [1, 2, 3, 4]:
            x = []
            for _ in range(N):
                dims = torch.randint(K, size=(ndim,)).tolist()
                x.append(torch.rand(dims, device=device))

            # set 0th element to an empty 1D tensor
            x[0] = torch.tensor([], dtype=x[0].dtype, device=device)

            # set 1st element to an empty tensor with correct number of dims
            x[1] = x[1].new_zeros(*[[0] * ndim])

            pad_size = [K] * ndim
            x_padded = struct_utils.list_to_padded(
                x, pad_size=pad_size, pad_value=0.0, equisized=False
            )

            for dim in range(ndim):
                self.assertEqual(x_padded.shape[dim + 1], K)

            self._check_list_to_padded_slices(x, x_padded, ndim)

            # check for no pad size (defaults to max dimension)
            x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
            max_sizes = (
                max(
                    (0 if (y.nelement() == 0 and y.ndim == 1) else y.shape[dim])
                    for y in x
                )
                for dim in range(ndim)
            )
            for dim, max_size in enumerate(max_sizes):
                self.assertEqual(x_padded.shape[dim + 1], max_size)

            self._check_list_to_padded_slices(x, x_padded, ndim)

            # check for equisized
            x = [torch.rand((K, *([10] * (ndim - 1))), device=device) for _ in range(N)]
            x_padded = struct_utils.list_to_padded(x, equisized=True)
            self.assertClose(x_padded, torch.stack(x, 0))

        # catch ValueError for invalid dimensions
        with self.assertRaisesRegex(ValueError, "Pad size must"):
            pad_size = [K] * (ndim + 1)
            struct_utils.list_to_padded(
                x, pad_size=pad_size, pad_value=0.0, equisized=False
            )

        # invalid input tensor dimensions
        x = []
        ndim = 3
        for _ in range(N):
            dims = torch.randint(K, size=(ndim,)).tolist()
            x.append(torch.rand(dims, device=device))
        pad_size = [K] * 2
        with self.assertRaisesRegex(ValueError, "Pad size must"):
            x_padded = struct_utils.list_to_padded(
                x, pad_size=pad_size, pad_value=0.0, equisized=False
            )
Example #6
    def test_padded_to_packed(self):
        device = torch.device("cuda:0")
        N = 5
        K = 20
        ndim = 2
        dims = [K] * ndim
        x = torch.rand([N] + dims, device=device)

        # Case 1: no split_size or pad_value provided
        # Check output is just the flattened input.
        x_packed = struct_utils.padded_to_packed(x)
        self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2]))
        self.assertClose(x_packed, x.reshape(-1, K))

        # Case 2: pad_value is provided.
        # Check each section of the packed tensor matches the
        # corresponding unpadded elements of the padded tensor.
        # Check that only rows where all the values are padded
        # are removed in the conversion to packed.
        pad_value = -1
        x_list = []
        split_size = []
        for _ in range(N):
            dim = torch.randint(K, size=(1,)).item()
            # Add some random values in the input which are the same as the pad_value.
            # These should not be filtered out.
            x_list.append(
                torch.randint(low=pad_value, high=10, size=(dim, K), device=device)
            )
            split_size.append(dim)
        x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
        x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
        curr = 0
        for i in range(N):
            self.assertClose(x_packed[curr : curr + split_size[i], ...], x_list[i])
            self.assertClose(torch.cat(x_list), x_packed)
            curr += split_size[i]

        # Case 3: split_size is provided.
        # Check each section of the packed tensor matches the corresponding
        # unpadded elements.
        x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)
        curr = 0
        for i in range(N):
            self.assertClose(x_packed[curr : curr + split_size[i], ...], x_list[i])
            self.assertClose(torch.cat(x_list), x_packed)
            curr += split_size[i]

        # Case 4: split_size of the wrong shape is provided.
        # Raise an error.
        split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
        with self.assertRaisesRegex(ValueError, "1-dimensional"):
            x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)

        split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
        with self.assertRaisesRegex(
            ValueError, "same length as inputs first dimension"
        ):
            x_packed = struct_utils.padded_to_packed(x_padded, split_size=split_size)

        # Case 5: both pad_value and split_size are provided.
        # Raise an error.
        with self.assertRaisesRegex(ValueError, "Only one of"):
            x_packed = struct_utils.padded_to_packed(
                x_padded, split_size=split_size, pad_value=-1
            )

        # Case 6: Input has more than 3 dims.
        # Raise an error.
        with self.assertRaisesRegex(ValueError, "Supports only"):
            x = torch.rand((N, K, K, K, K), device=device)
            split_size = torch.randint(1, K, size=(N,)).tolist()
            struct_utils.padded_to_packed(x, split_size=split_size)
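As a standalone illustration of Case 2 above, a small sketch (values chosen by hand) showing that padded_to_packed with pad_value drops only the rows made up entirely of the pad value:

import torch
from pytorch3d.structures import utils as struct_utils

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2 valid rows
b = torch.tensor([[5.0, 6.0]])              # 1 valid row
padded = struct_utils.list_to_padded([a, b], pad_value=-1.0)    # (2, 2, 2)
packed = struct_utils.padded_to_packed(padded, pad_value=-1.0)  # (3, 2)
assert torch.equal(packed, torch.cat([a, b]))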
Example #7
    def test_padded_to_packed(self):
        N = 2
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        faces_uvs_list = [
            torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
        ]  # (N, 3, 3)
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
        faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
        verts_uvs_padded = list_to_padded(verts_uvs_list)
        tex = Textures(
            maps=torch.ones((N, 16, 16, 3)),
            faces_uvs=faces_uvs_padded,
            verts_uvs=verts_uvs_padded,
        )

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_faces_per_mesh = (
            faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist())
        tex1._num_verts_per_mesh = torch.tensor([5, 4])
        faces_packed = tex1.faces_uvs_packed()
        verts_packed = tex1.verts_uvs_packed()
        faces_list = tex1.faces_uvs_list()
        verts_list = tex1.verts_uvs_list()

        for f1, f2 in zip(faces_uvs_list, faces_list):
            self.assertTrue((f1 == f2).all().item())

        for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
            idx = f.unique()
            self.assertTrue((v1[idx] == v2).all().item())

        self.assertTrue(faces_packed.shape == (3 + 2, 3))

        # verts_packed is just flattened verts_padded.
        # split sizes are not used for verts_uvs.
        self.assertTrue(verts_packed.shape == (9 * 2, 2))

        # Case where num_faces_per_mesh is not set
        tex2 = tex.clone()
        faces_packed = tex2.faces_uvs_packed()
        verts_packed = tex2.verts_uvs_packed()
        faces_list = tex2.faces_uvs_list()
        verts_list = tex2.verts_uvs_list()

        # Packed is just flattened padded as num_faces_per_mesh
        # has not been provided.
        self.assertTrue(verts_packed.shape == (9 * 2, 2))
        self.assertTrue(faces_packed.shape == (3 * 2, 3))

        for i in range(N):
            self.assertTrue((
                faces_list[i] == faces_uvs_padded[i,
                                                  ...].squeeze()).all().item())

        for i in range(N):
            self.assertTrue((
                verts_list[i] == verts_uvs_padded[i,
                                                  ...].squeeze()).all().item())
Example #8
def corresponding_points_alignment(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-9,
) -> SimilarityTransform:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`)  between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **weights**: Batch of non-negative weights of
            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
            tensors that may have different shapes; in that case, the length of
            i-th tensor should be equal to the number of points in X_i and Y_i.
            Passing `None` means uniform weights.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element named tuple `SimilarityTransform` containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        raise ValueError(
            "Point sets X and Y have to have the same \
            number of batches, points and dimensions."
        )
    if weights is not None:
        if isinstance(weights, list):
            if any(np != w.shape[0] for np, w in zip(num_points, weights)):
                raise ValueError(
                    "number of weights should equal to the "
                    + "number of points in the point cloud."
                )
            weights = [w[..., None] for w in weights]
            weights = strutil.list_to_padded(weights)[..., 0]

        if Xt.shape[:2] != weights.shape:
            raise ValueError("weights should have the same first two dimensions as X.")

    b, n, dim = Xt.shape

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
            < num_points[:, None]
        ).type_as(Xt)
        weights = mask if weights is None else mask * weights.type_as(Xt)

    # compute the centroids of the point sets
    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)

    # mean-center the point sets
    Xc = Xt - Xmu
    Yc = Yt - Ymu

    total_weight = torch.clamp(num_points, 1)
    # special handling for heterogeneous point clouds and/or input weights
    if weights is not None:
        Xc *= weights[:, :, None]
        Yc *= weights[:, :, None]
        total_weight = torch.clamp(weights.sum(1), eps)

    if (num_points < (dim + 1)).any():
        warnings.warn(
            "The size of one of the point clouds is <= dim+1. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / total_weight[:, None, None]

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # catch ambiguous rotation by checking the magnitude of singular values
    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
        num_points < (dim + 1)
    ).any():
        warnings.warn(
            "Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / total_weight

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
    else:
        # translation component
        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return SimilarityTransform(R, T, s)
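A hedged usage sketch of the function above: build corresponding point sets related by a known rigid transform and check that it is recovered (pytorch3d.transforms.random_rotations is used here only to generate test rotations; sizes are illustrative):

import torch
from pytorch3d.ops import corresponding_points_alignment
from pytorch3d.transforms import random_rotations

X = torch.rand(4, 100, 3)               # (minibatch, num_point, d)
R_gt = random_rotations(4)              # ground-truth rotations
T_gt = torch.rand(4, 3)                 # ground-truth translations
Y = torch.bmm(X, R_gt) + T_gt[:, None]  # Y[i] = X[i] R[i] + T[i], with s[i] = 1

R, T, s = corresponding_points_alignment(X, Y, estimate_scale=False)
print((R - R_gt).abs().max(), (T - T_gt).abs().max())  # both close to zero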
Example #9
    def init_pointclouds(N, P1, P2, device, requires_grad: bool = True):
        """
        Create 2 pointclouds objects and associated padded points/normals tensors by
        starting from lists. The clouds and tensors have the same data. The
        leaf nodes for the clouds are a list of tensors. The padded tensor can be
        used directly as a leaf node.
        """
        p1_lengths = torch.randint(P1,
                                   size=(N, ),
                                   dtype=torch.int64,
                                   device=device)
        p2_lengths = torch.randint(P2,
                                   size=(N, ),
                                   dtype=torch.int64,
                                   device=device)
        weights = torch.rand((N, ), dtype=torch.float32, device=device)

        # list of points and normals tensors
        p1_list = []
        p2_list = []
        n1_list = []
        n2_list = []
        for i in range(N):
            l1 = p1_lengths[i]
            l2 = p2_lengths[i]
            p1_list.append(
                torch.rand((l1, 3), dtype=torch.float32, device=device))
            p2_list.append(
                torch.rand((l2, 3), dtype=torch.float32, device=device))
            n1_list.append(
                torch.rand((l1, 3), dtype=torch.float32, device=device))
            n2_list.append(
                torch.rand((l2, 3), dtype=torch.float32, device=device))

        n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
        n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]

        # Clone the lists and initialize padded tensors.
        p1 = list_to_padded([p.clone() for p in p1_list])
        p2 = list_to_padded([p.clone() for p in p2_list])
        n1 = list_to_padded([p.clone() for p in n1_list])
        n2 = list_to_padded([p.clone() for p in n2_list])

        # Set requires_grad for all tensors in the lists and
        # padded tensors.
        if requires_grad:
            for p in p2_list + p1_list + n1_list + n2_list + [p1, p2, n1, n2]:
                p.requires_grad = True

        # Create pointclouds objects
        cloud1 = Pointclouds(points=p1_list, normals=n1_list)
        cloud2 = Pointclouds(points=p2_list, normals=n2_list)

        # Return pointclouds objects and padded tensors
        return points_normals(
            p1_lengths=p1_lengths,
            p2_lengths=p2_lengths,
            cloud1=cloud1,
            cloud2=cloud2,
            p1=p1,
            p2=p2,
            n1=n1,
            n2=n2,
            weights=weights,
        )
Example #10
def main(args):
    # set for reproducibility
    torch.manual_seed(42)
    if args.dtype == "float":
        args.dtype = torch.float32
    elif args.dtype == "double":
        args.dtype = torch.float64

    # ## 1. Set up Cameras and load ground truth positions

    # load the SE3 graph of relative/absolute camera positions
    if (args.input_folder / "images.bin").isfile():
        ext = '.bin'
    elif (args.input_folder / "images.txt").isfile():
        ext = '.txt'
    else:
        print('error: no images.bin or images.txt found in input folder')
        return
    cameras, images, points3D = read_model(args.input_folder, ext)

    images_df = pd.DataFrame.from_dict(images, orient="index").set_index("id")
    cameras_df = pd.DataFrame.from_dict(cameras,
                                        orient="index").set_index("id")
    points_df = pd.DataFrame.from_dict(points3D,
                                       orient="index").set_index("id")
    print(points_df)
    print(images_df)
    print(cameras_df)

    ref_pointcloud = PyntCloud.from_file(args.ply)
    ref_pointcloud = torch.from_numpy(ref_pointcloud.xyz).to(device,
                                                             dtype=args.dtype)

    points_3d = np.stack(points_df["xyz"].values)
    points_3d = torch.from_numpy(points_3d).to(device, dtype=args.dtype)

    cameras_R = np.stack(
        [qvec2rotmat(q) for _, q in images_df["qvec"].iteritems()])
    cameras_R = torch.from_numpy(cameras_R).to(device,
                                               dtype=args.dtype).transpose(
                                                   1, 2)

    cameras_T = torch.from_numpy(np.stack(images_df["tvec"].values)).to(
        device, dtype=args.dtype)

    cameras_params = torch.from_numpy(np.stack(
        cameras_df["params"].values)).to(device, dtype=args.dtype)
    cameras_params = cameras_params[:, :4]
    print(cameras_params)

    # Construct visibility map: True at (frame, point) if the point is visible from that frame, False otherwise
    # Thus, we can ignore reprojection errors for invisible points
    visibility = np.full((cameras_R.shape[0], points_3d.shape[0]), False)
    visibility = pd.DataFrame(visibility,
                              index=images_df.index,
                              columns=points_df.index)

    points_2D_gt = []
    for idx, (pts_ids, xy) in images_df[["point3D_ids", "xys"]].iterrows():
        pts_ids_clean = pts_ids[pts_ids != -1]
        pts_2D = pd.DataFrame(xy[pts_ids != -1], index=pts_ids_clean)
        pts_2D = pts_2D[~pts_2D.index.duplicated(keep=False)].reindex(
            points_df.index).dropna()
        points_2D_gt.append(pts_2D.values)
        visibility.loc[idx, pts_2D.index] = True

    print(visibility)

    visibility = torch.from_numpy(visibility.values).to(device)
    eps = 1e-3
    # The visibility map is very sparse, so we can use PyTorch3D's list_to_padded to reduce points_2D
    # to size (num_frames, max points seen by any frame)
    points_2D_gt = list_to_padded([torch.from_numpy(p) for p in points_2D_gt],
                                  pad_value=eps).to(device, dtype=args.dtype)
    print(points_2D_gt)

    cameras_df["raw_id"] = np.arange(len(cameras_df))
    cameras_id_per_image = torch.from_numpy(
        cameras_df["raw_id"][images_df["camera_id"]].values).to(device)
    # the number of absolute camera positions
    N = len(images_df)
    nonzer = (points_2D_gt != eps).all(dim=-1)

    # print(padded)
    # print(points_2D_gt, points_2D_gt.shape)

    # ## 2. Define optimization functions
    #
    # ### Relative cameras and camera distance
    # We now define two functions crucial for the optimization.
    #
    # **`calc_camera_distance`** compares a pair of cameras.
    # This function is important as it defines the loss that we are minimizing.
    # The method utilizes the `so3_relative_angle` function from the SO3 API.
    #
    # **`get_relative_camera`** computes the parameters of a relative camera
    # that maps between a pair of absolute cameras. Here we utilize the `compose`
    # and `inverse` class methods from the PyTorch3D Transforms API.

    def calc_camera_distance(cam_1, cam_2):
        """
        Calculates the divergence of a batch of pairs of cameras cam_1, cam_2.
        The distance is composed of the cosine of the relative angle between
        the rotation components of the camera extrinsics and the l2 distance
        between the translation vectors.
        """
        # rotation distance
        R_distance = (
            1. - so3_relative_angle(cam_1.R, cam_2.R, cos_angle=True)).mean()
        # translation distance
        T_distance = ((cam_1.T - cam_2.T)**2).sum(1).mean()
        # the final distance is the sum
        return R_distance + T_distance
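
    # The comments above also mention a `get_relative_camera` helper, which is not
    # defined anywhere in this script. The sketch below (adapted from the PyTorch3D
    # bundle adjustment tutorial and unused here) shows how it can be built with the
    # `inverse` and `compose` methods of the Transforms API; treat it as an assumption.
    def get_relative_camera(cams, edges):
        # world-to-view transforms of the two cameras in each (i, j) pair
        trans_i, trans_j = [
            PerspectiveCameras(
                R=cams.R[edges[:, i]], T=cams.T[edges[:, i]], device=device
            ).get_world_to_view_transform()
            for i in (0, 1)
        ]
        # relative transform g_i^{-1} composed with g_j
        trans_rel = trans_i.inverse().compose(trans_j)
        matrix_rel = trans_rel.get_matrix()
        # re-assemble a camera from the relative rotation / translation
        return PerspectiveCameras(
            R=matrix_rel[:, :3, :3], T=matrix_rel[:, 3, :3], device=device
        )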

    # ## 3. Optimization
    # Finally, we start the optimization of the absolute cameras.
    #
    # We use SGD with momentum and optimize over `log_R_absolute` and `T_absolute`.
    #
    # As mentioned earlier, `log_R_absolute` is the axis angle representation of the
    # rotation part of our absolute cameras. We can obtain the 3x3 rotation matrix
    # `R_absolute` that corresponds to `log_R_absolute` with:
    #
    # `R_absolute = so3_exponential_map(log_R_absolute)`
    #

    fxfyu0v0 = cameras_params[cameras_id_per_image]
    cameras_absolute_gt = PerspectiveCameras(
        focal_length=fxfyu0v0[:, :2],
        principal_point=fxfyu0v0[:, 2:],
        R=cameras_R,
        T=cameras_T,
        device=device,
    )

    # Normally, the 2D points are the ones we should use to minimize reprojection errors.
    # But we have been dealing with instability, so we can instead reproject the 3D points and use those
    # reprojections as targets, since we assume Colmap's bundle adjuster has already converged on its own.
    use_3d_points = True
    if use_3d_points:
        with torch.no_grad():
            padded_points = list_to_padded(
                [points_3d[visibility[c]] for c in range(N)], pad_value=1e-3)
            points_2D_gt = cameras_absolute_gt.transform_points(
                padded_points, eps=1e-4)[:, :, :2]
            relative_points_gt = padded_points @ cameras_R + cameras_T

    # The starting point is normally points_3d, cameras_R and cameras_T.
    # For a stability test, you can add noise and see if the optimization
    # gets back to the initial state (spoiler alert: it's complicated).
    # Set noise and shift to 0 for a normal starting point
    noise = 0
    shift = 0.1
    points_init = points_3d + noise * torch.randn(
        points_3d.shape, dtype=torch.float32, device=device) + shift

    log_R_init = so3_log_map(cameras_R) + noise * torch.randn(
        N, 3, dtype=torch.float32, device=device)
    T_init = cameras_T + noise * torch.randn(
        cameras_T.shape, dtype=torch.float32, device=device) - shift
    cams_init = cameras_params  # + noise * torch.randn(cameras_params.shape, dtype=torch.float32, device=device)

    # instantiate a copy of the initialization of log_R / T
    log_R = log_R_init.clone().detach()
    log_R.requires_grad = True
    T = T_init.clone().detach()
    T.requires_grad = True

    cams_params = cams_init.clone().detach()
    cams_params.requires_grad = True

    points = points_init.clone().detach()
    points.requires_grad = True

    # init the optimizer
    # Different learning rates per parameter? Intuitively, it should be higher for T and lower for log_R.
    # The camera params could be optimized as well, but it's unlikely to be interesting.
    param_groups = [{
        'params': points,
        'lr': args.lr
    }, {
        'params': log_R,
        'lr': 0.1 * args.lr
    }, {
        'params': T,
        'lr': 2 * args.lr
    }, {
        'params': cams_params,
        'lr': 0
    }]
    optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9)

    # run the optimization
    n_iter = 200000  # fix the number of iterations
    # Compute inliers
    # In the model, some 3D points have a reprojection that is way off compared to the
    # target 2D point, which is potentially a great source of instability. inliers
    # keeps track of those problematic points so they can be discarded from the optimization.
    discard_outliers = True
    if discard_outliers:
        with torch.no_grad():
            padded_points = list_to_padded(
                [points_3d[visibility[c]] for c in range(N)], pad_value=1e-3)
            projected_points = cameras_absolute_gt.transform_points(
                padded_points, eps=1e-4)[:, :, :2]
            points_distance = ((projected_points[nonzer] -
                                points_2D_gt[nonzer])**2).sum(dim=1)
            inliers = (points_distance < 100).clone().detach()
            print(inliers)
    else:
        inliers = points_2D_gt[nonzer] == points_2D_gt[
            nonzer]  # All true, except NaNs
    loss_log = []
    cam_dist_log = []
    pts_dist_log = []
    for it in range(n_iter):
        # re-init the optimizer gradients
        optimizer.zero_grad()
        R = so3_exponential_map(log_R)

        fxfyu0v0 = cams_params[cameras_id_per_image]
        # get the current absolute cameras
        cameras_absolute = PerspectiveCameras(
            focal_length=fxfyu0v0[:, :2],
            principal_point=fxfyu0v0[:, 2:],
            R=R,
            T=T,
            device=device,
        )

        padded_points = list_to_padded(
            [points[visibility[c]] for c in range(N)], pad_value=1e-3)

        # Two ways of optimizing:
        # 1) Minimize the 2D projection error. Potentially unstable, especially with very close points.
        # This is problematic because close points are the ones for which we want the pose updates to be small,
        # yet gradient descent gives them the largest step sizes. Adam might help, but some instability remains.
        #
        # 2) Minimize the 3D relative position error (the initial 3D relative positions are treated as ground truth).
        # No more instability for very close points, but the 2D reprojection error is not guaranteed to be minimized.

        minimize_2d = True
        chamfer_weight = 1e3
        verbose = True

        chamfer_dist = chamfer_distance(ref_pointcloud[None], points[None])[0]
        if minimize_2d:
            projected_points_3D = cameras_absolute.transform_points(
                padded_points, eps=1e-4)[..., :2]
            projected_points = projected_points_3D[:, :, :2]
            # Discard points with a depth < 0 (theoretically impossible)
            inliers = inliers & (projected_points_3D[:, :, 2][nonzer] > 0)

            # Plot point distances for the first image
            # distances = (projected_points[0] - points_2D_gt[0]).norm(dim=-1).detach().cpu().numpy()
            # from matplotlib import pyplot as plt
            # plt.plot(distances[:(visibility[0]).sum()])

            # Different loss functions for reprojection error minimization
            # points_distance = smooth_l1_loss(projected_points, points_2D_gt)
            # points_distance = (smooth_l1_loss(projected_points, points_2D_gt, reduction='none')[nonzer]).sum(dim=1)
            proj_error = ((projected_points[nonzer] -
                           points_2D_gt[nonzer])**2).sum(dim=1)
            proj_error_filtered = proj_error[inliers]
        else:
            projected_points_3D = padded_points @ R + T

            # Plot point distances for the first image
            # distances = (projected_points_3D[0] - relative_points_gt[0]).norm(dim=-1).detach().cpu().numpy()
            # from matplotlib import pyplot as plt
            # plt.plot(distances[:(visibility[0]).sum()])

            # Different loss functions for reprojection error minimization
            # points_distance = smooth_l1_loss(projected_points, points_2D_gt)
            # points_distance = (smooth_l1_loss(projected_points, points_2D_gt, reduction='none')[nonzer]).sum(dim=1)
            proj_error = ((projected_points_3D[nonzer] -
                           relative_points_gt[nonzer])**2).sum(dim=1)
            proj_error_filtered = proj_error[inliers]

        loss = proj_error_filtered.mean() + chamfer_weight * chamfer_dist
        loss.backward()

        if verbose:
            print("faulty elements (with nan grad) :")
            faulty_points = torch.arange(
                points.shape[0])[points.grad[:, 0] != points.grad[:, 0]]
            faulty_images = torch.arange(
                log_R.shape[0])[log_R.grad[:, 0] != log_R.grad[:, 0]]
            faulty_cams = torch.arange(cams_params.shape[0])[
                cams_params.grad[:, 0] != cams_params.grad[:, 0]]
            faulty_projected_points = torch.arange(
                projected_points.shape[1])[torch.isnan(
                    projected_points.grad).any(dim=2)[0]]

            # Print Tensor that would become NaN, should the gradient be applied
            print("Faulty Rotation (log) and translation")
            print(faulty_images)
            print(log_R[faulty_images])
            print(T[faulty_images])
            print("Faulty 3D colmap points")
            print(faulty_points)
            print(points[faulty_points])
            print("Faulty Cameras")
            print(faulty_cams)
            print(cams_params[faulty_cams])
            print("Faulty 2D points")
            print(projected_points[faulty_projected_points])
            first_faulty_point = points_df.iloc[int(faulty_points[0])]
            related_faulty_images = images_df.loc[
                first_faulty_point["image_ids"][0]]

            print("First faulty point, and images where it is seen")
            print(first_faulty_point)
            print(related_faulty_images)

        # apply the gradients
        optimizer.step()

        # plot and print status message
        if it % 2000 == 0 or it == n_iter - 1:
            camera_distance = calc_camera_distance(cameras_absolute,
                                                   cameras_absolute_gt)
            print(
                'iteration = {}; loss = {}, chamfer distance = {}, camera_distance = {}'
                .format(it, loss, chamfer_dist, camera_distance))
            loss_log.append(loss.item())
            pts_dist_log.append(chamfer_dist.item())
            cam_dist_log.append(camera_distance.item())
        if it % 20000 == 0 or it == n_iter - 1:
            with torch.no_grad():
                from matplotlib import pyplot as plt
                plt.hist(
                    torch.sqrt(proj_error_filtered).detach().cpu().numpy())
        if it % 200000 == 0 or it == n_iter - 1:
            plt.figure()
            plt.plot(loss_log)
            plt.figure()
            plt.plot(pts_dist_log, label="chamfer_dist")
            plt.plot(cam_dist_log, label="cam_dist")
            plt.legend()
            plot_camera_scene(
                cameras_absolute, cameras_absolute_gt, points, ref_pointcloud,
                'iteration={}; chamfer distance={}'.format(
                    it, chamfer_dist))

    print('Optimization finished.')