Example #1
import math
from typing import Optional, Tuple

import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import Scale

# `fit_circle_in_3d` and the private helpers (`_figure_eight`, `_trefoil_knot`,
# `_figure_eight_knot`, `_fit_plane`, `_disambiguate_normal`) come from the
# surrounding module and are not shown in this snippet.


def generate_eval_video_cameras(
    train_cameras,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
    focal_length: Optional[torch.FloatTensor] = None,
    principal_point: Optional[torch.FloatTensor] = None,
    time: Optional[torch.FloatTensor] = None,
    infer_up_as_plane_normal: bool = True,
    traj_offset: Optional[Tuple[float, float, float]] = None,
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
    """
    Generate a camera trajectory for rendering a scene from multiple viewpoints.

    Args:
        train_cameras: The set of cameras from the training dataset.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for `trajectory_type="circular_lsq_fit"`.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
    Returns:
        A batch of `PerspectiveCameras` that can be used as the test dataset.
    """
    if trajectory_type in ("figure_eight", "trefoil_knot",
                           "figure_eight_knot"):
        cam_centers = train_cameras.get_camera_center()
        # index of the camera whose center is nearest to the mean of all centers
        mean_camera_idx = (((cam_centers -
                             cam_centers.mean(dim=0)[None])**2).sum(dim=1).min(
                                 dim=0).indices)
        # generate the knot trajectory in canonical coords
        if time is None:
            time = torch.linspace(0, 2 * math.pi,
                                  n_eval_cams + 1)[:n_eval_cams]
        else:
            assert time.numel() == n_eval_cams
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        else:
            raise ValueError(f"bad trajectory type: {trajectory_type}")
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        mean_camera = PerspectiveCameras(
            **{
                k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
                for k in ("focal_length", "principal_point", "R", "T")
            })
        traj_trans = Scale(
            cam_centers.std(dim=0).mean() * trajectory_scale).compose(
                mean_camera.get_world_to_view_transform().inverse())

        if traj_offset_canonical is not None:
            traj_trans = traj_trans.translate(
                torch.FloatTensor(traj_offset_canonical)[None].to(traj))

        traj = traj_trans.transform_points(traj)

        plane_normal = _fit_plane(cam_centers)[:, 0]
        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    elif trajectory_type == "circular_lsq_fit":
        # fit a 3D circle (and its supporting plane) to the camera centers
        cam_centers = train_cameras.get_camera_center()

        if time is not None:
            angle = time
        else:
            angle = torch.linspace(0, 2.0 * math.pi,
                                   n_eval_cams).to(cam_centers)

        fit = fit_circle_in_3d(
            cam_centers,
            angles=angle,
            offset=angle.new_tensor(traj_offset_canonical)
            if traj_offset_canonical is not None else None,
            up=angle.new_tensor(up),
        )
        traj = fit.generated_points

        # scale the trajectory
        _t_mu = traj.mean(dim=0, keepdim=True)
        traj = (traj - _t_mu) * trajectory_scale + _t_mu

        plane_normal = fit.normal

        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    else:
        raise ValueError(f"Uknown trajectory_type {trajectory_type}.")

    if traj_offset is not None:
        traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center, ),  # (1, 3)
        up=(up, ),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    if focal_length is None:
        focal_length = train_cameras.focal_length.mean(dim=0).repeat(
            n_eval_cams, 1)
    if principal_point is None:
        principal_point = train_cameras.principal_point.mean(dim=0).repeat(
            n_eval_cams, 1)

    test_cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=focal_length.device,
    )

    # _visdom_plot_scene(
    #     train_cameras,
    #     test_cameras,
    # )

    return test_cameras
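
The private curve helpers called above (`_figure_eight`, `_trefoil_knot`, `_figure_eight_knot`) are not part of the snippet. Below is a minimal sketch of what they could look like, assuming the standard parametric equations of the corresponding curves; the `z_scale` argument is an illustrative assumption used to flatten the curves along the z axis:

def _figure_eight(t: torch.Tensor, z_scale: float = 0.5) -> torch.Tensor:
    # figure-of-eight curve in the xy-plane with a small z oscillation
    x = t.cos()
    y = (2 * t).sin() / 2
    z = t.sin() * z_scale
    return torch.stack((x, y, z), dim=-1)

def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5) -> torch.Tensor:
    # standard trefoil-knot parametrization
    x = t.sin() + 2 * (2 * t).sin()
    y = t.cos() - 2 * (2 * t).cos()
    z = -(3 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)

def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5) -> torch.Tensor:
    # standard figure-eight-knot parametrization
    x = (2 + (2 * t).cos()) * (3 * t).cos()
    y = (2 + (2 * t).cos()) * (3 * t).sin()
    z = (4 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)

A possible call, assuming `train_cameras` is a batch of PerspectiveCameras estimated for the training frames:

test_cameras = generate_eval_video_cameras(
    train_cameras,
    n_eval_cams=100,
    trajectory_type="circular_lsq_fit",
    trajectory_scale=1.1,
)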
Example #2
    def compute(self, points: torch.Tensor, sdf: torch.Tensor,
                mesh_gt: Meshes):
        """
        Rasterize the faces of `mesh_gt` from a distant camera facing the origin,
        transform the predicted point positions to the camera view and project them
        to obtain normalized image coordinates.
        The number of z-buffer values at those coordinates that are larger than a
        point's depth (i.e. the number of faces behind it along the view ray)
        determines the sign of its SDF: an even count means the point lies outside
        the mesh (positive sign), an odd count that it lies inside (negative sign).
        """
        assert (points.ndim == 2 and points.shape[-1] == 3)
        device = points.device
        faces_per_pixel = 4
        with torch.autograd.no_grad():
            # a point that is definitely outside the mesh as camera center
            ray0 = torch.tensor([2, 2, 2], device=device,
                                dtype=points.dtype).view(1, 3)
            R, T = look_at_view_transform(eye=ray0,
                                          at=((0, 0, 0), ),
                                          up=((0, 0, 1), ))
            cameras = PerspectiveCameras(R=R, T=T, device=device)
            rasterizer = MeshRasterizer(cameras=cameras,
                                        raster_settings=RasterizationSettings(
                                            faces_per_pixel=faces_per_pixel, ))
            fragments = rasterizer(mesh_gt)

            z_predicted = cameras.get_world_to_view_transform(
            ).transform_points(points=points.unsqueeze(0))[..., -1:]
            # normalized image coordinates, negated so the top-left corner has the
            # smallest values (grid_sample convention)
            screen_xy = -cameras.transform_points(points.unsqueeze(0))[..., :2]
            outside_screen = (screen_xy.abs() > 1.0).any(dim=-1)

            # pix_to_face, zbuf, bary_coords, dists
            assert (fragments.zbuf.shape[-1] == faces_per_pixel)
            zbuf = torch.nn.functional.grid_sample(
                fragments.zbuf.permute(0, 3, 1, 2),
                screen_xy.clamp(-1.0, 1.0).view(1, -1, 1, 2),
                align_corners=False,
                mode='nearest')
            zbuf[outside_screen.unsqueeze(1).expand(-1, zbuf.shape[1],
                                                    -1)] = -1.0
            sign = (((zbuf > z_predicted).sum(dim=1) %
                     2) == 0).type_as(points).view(screen_xy.shape[1])
            sign = sign * 2 - 1

        pcls = PointClouds3D(points.unsqueeze(0)).to(device=device)

        points_first_idx = pcls.cloud_to_packed_first_idx()
        max_points = pcls.num_points_per_cloud().max().item()

        # packed representation for faces
        verts_packed = mesh_gt.verts_packed()
        faces_packed = mesh_gt.faces_packed()
        tris = verts_packed[faces_packed]  # (T, 3, 3)
        tris_first_idx = mesh_gt.mesh_to_faces_packed_first_idx()
        max_tris = mesh_gt.num_faces_per_mesh().max().item()

        # point to face distance: shape (P,)
        point_to_face = point_face_distance(points, points_first_idx, tris,
                                            tris_first_idx, max_points)
        point_to_face = sign * torch.sqrt(eps_sqrt(point_to_face))
        loss = (point_to_face - sdf)**2
        return loss
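
A minimal usage sketch, assuming the `compute` method above belongs to a signed-distance loss module (named `SdfLoss` here purely for illustration) and that the project-specific helpers `PointClouds3D`, `point_face_distance` and `eps_sqrt` (which presumably clamps its argument to a small positive value before the square root) are available in scope:

import torch
from pytorch3d.utils import ico_sphere

# Hypothetical setup; `SdfLoss` is an illustrative name for the class that owns `compute`.
mesh_gt = ico_sphere(level=2)              # ground-truth mesh (unit sphere)
points = torch.rand(1024, 3) * 2.0 - 1.0   # query points in [-1, 1]^3
sdf_pred = torch.zeros(1024)               # SDF values predicted by a network at the points

# loss = SdfLoss().compute(points, sdf_pred, mesh_gt)  # per-point squared error, shape (P,)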