Example No. 1
def get_relative_camera(cams, edges):
    """
    For each pair of indices (i,j) in "edges" generate a camera
    that maps from the coordinates of the camera cams[i] to
    the coordinates of the camera cams[j] - PyTorch3D
    """

    # first generate the world-to-view Transform3d objects of each
    # camera pair (i, j) according to the edges argument
    trans_i, trans_j = [
        PerspectiveCameras(
            R=cams.R[edges[:, i]],
            T=cams.T[edges[:, i]],
            device=device,
        ).get_world_to_view_transform() for i in (0, 1)
    ]

    # compose the relative transformation as g_i^{-1} g_j
    trans_rel = trans_i.inverse().compose(trans_j)

    # generate a camera from the relative transform
    matrix_rel = trans_rel.get_matrix()
    cams_relative = PerspectiveCameras(
        R=matrix_rel[:, :3, :3],
        T=matrix_rel[:, 3, :3],
        device=device,
    )
    return cams_relative
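
A minimal usage sketch (not part of the original snippet; it assumes a standard PyTorch3D install and the module-level `device` used by `get_relative_camera`): build a toy batch of four cameras and ask for the relative cameras 0 -> 1 and 2 -> 3.

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms import random_rotations

device = torch.device("cpu")
# toy batch of 4 cameras with random rotations and translations
cams = PerspectiveCameras(R=random_rotations(4), T=torch.randn(4, 3), device=device)
# pairs (i, j): relative cameras mapping cam 0 -> cam 1 and cam 2 -> cam 3
edges = torch.tensor([[0, 1], [2, 3]])
cams_rel = get_relative_camera(cams, edges)
print(cams_rel.R.shape, cams_rel.T.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3])
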
Example No. 2
def plot_cams_from_poses(pose_gt, pose_pred, device: str):
    """
    """
    R_gt, T_gt = pose_gt
    cameras_gt = PerspectiveCameras(device=device, R=R_gt, T=T_gt)
    if pose_pred is None:
        fig = plot_trajectory_cameras(cameras_gt, device)
        return fig

    R_pred, T_pred = pose_pred
    cameras_pred = PerspectiveCameras(device=device, R=R_pred, T=T_pred)
    fig = plot_camera_scene(cameras_pred, cameras_gt, "final_preds", device)

    return fig
Example No. 3
def init_cameras(
        batch_size: int = 10,
        image_size: Optional[Tuple[int, int]] = (50, 50),
        ndc: bool = False,
):
    """
    Initialize a batch of cameras whose extrinsics rotate the cameras around
    the world's y axis.
    Depending on whether we want an NDC-space (`ndc==True`) or a screen-space camera,
    the camera's focal length and principal point are initialized accordingly:
        For `ndc==False`, p0 = image_size / 2 and focal_length = max(image_size).
        For `ndc==True`, focal_length = 1.0 and p0 = 0.0.
    The z-coordinate of the translation vector of each camera is fixed to 1.5.
    """
    device = torch.device("cuda:0")

    # trivial rotations
    R = init_uniform_y_rotations(batch_size=batch_size, device=device)

    # move camera 1.5 m away from the scene center
    T = torch.zeros((batch_size, 3), device=device)
    T[:, 2] = 1.5

    if ndc:
        p0 = torch.zeros(batch_size, 2, device=device)
        focal = torch.ones(batch_size, device=device)
    else:
        p0 = torch.ones(batch_size, 2, device=device)
        p0[:, 0] *= image_size[1] * 0.5
        p0[:, 1] *= image_size[0] * 0.5
        focal = max(*image_size) * torch.ones(batch_size, device=device)

    # convert to a Camera object
    cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device)
    return cameras
Example No. 4
    def test_cameras(self):
        """
        DVR cameras
        """
        device = torch.device('cuda:0')
        input_dir = '/home/ywang/Documents/points/neural_splatter/differentiable_volumetric_rendering_upstream/data/DTU/scan106/scan106'
        out_dir = os.path.join('tests', 'outputs', 'test_dvr_data')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        dvr_camera_file = os.path.join(input_dir, 'cameras.npz')
        dvr_camera_dict = np.load(dvr_camera_file)
        n_views = len(glob.glob(os.path.join(input_dir, 'image', '*.png')))

        focal_lengths = dvr_camera_dict['camera_mat_0'][(0,1),(0,1)].reshape(1,2)
        principal_point = dvr_camera_dict['camera_mat_0'][(0,1),(2,2)].reshape(1,2)
        cameras = PerspectiveCameras(focal_length=focal_lengths, principal_point=principal_point).to(device)
        # Define the settings for rasterization and shading.
        # Refer to raster_points.py for explanations of these parameters.
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=0.0,
            faces_per_pixel=5,
            # this setting controls whether naive or coarse-to-fine rasterization is used
            bin_size=None,
            max_faces_per_bin=None  # this setting is for coarse rasterization
        )
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=None, raster_settings=raster_settings),
            shader=SoftPhongShader(device=device)
        )
        mesh = trimesh.load_mesh('/home/ywang/Documents/points/neural_splatter/differentiable_volumetric_rendering_upstream/out/multi_view_reconstruction/birds/ours_depth_mvs/vis/000_0000477500.ply')
        textures = TexturesVertex(verts_features=torch.ones(
            1, mesh.vertices.shape[0], 3)).to(device=device)
        meshes = Meshes(verts=[torch.tensor(mesh.vertices).float()], faces=[torch.tensor(mesh.faces)],
                        textures=textures).to(device=device)
        for i in range(n_views):
            transform_mat = torch.from_numpy(dvr_camera_dict['scale_mat_%d' % i].T @ dvr_camera_dict['world_mat_%d' % i].T).to(device).unsqueeze(0).float()
            cameras.R, cameras.T = decompose_to_R_and_t(transform_mat)
            cameras._N = cameras.R.shape[0]
            imgs = renderer(meshes, cameras=cameras, zfar=1e4, znear=1.0)
            imageio.imwrite(os.path.join(out_dir, '%06d.png' % i), (imgs[0].detach().cpu().numpy()*255).astype('uint8'))
Example No. 5
def cli():
    """
    Basic example for the pulsar sphere renderer using the PyTorch3D interface.

    Writes to `basic-pt3d.png`.
    """
    LOGGER.info("Rendering on GPU...")
    torch.manual_seed(1)
    n_points = 10
    width = 1_000
    height = 1_000
    device = torch.device("cuda")
    # Generate sample data.
    vert_pos = torch.rand(n_points, 3, dtype=torch.float32,
                          device=device) * 10.0
    vert_pos[:, 2] += 25.0
    vert_pos[:, :2] -= 5.0
    vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device)
    pcl = Pointclouds(points=vert_pos[None, ...], features=vert_col[None, ...])
    # Alternatively, you can also use the look_at_view_transform to get R and T:
    # R, T = look_at_view_transform(
    #     dist=30.0, elev=0.0, azim=180.0, at=((0.0, 0.0, 30.0),), up=((0, 1, 0),),
    # )
    cameras = PerspectiveCameras(
        # The focal length must be double the size for PyTorch3D because of the NDC
        # coordinates spanning a range of two - and they must be normalized by the
        # sensor width (see the pulsar example). This means we need
        # 5.0 * 2.0 / 2.0 here to get results equivalent to pulsar's.
        focal_length=(5.0 * 2.0 / 2.0, ),
        R=torch.eye(3, dtype=torch.float32, device=device)[None, ...],
        T=torch.zeros((1, 3), dtype=torch.float32, device=device),
        image_size=((width, height), ),
        device=device,
    )
    vert_rad = torch.rand(n_points, dtype=torch.float32, device=device)
    raster_settings = PointsRasterizationSettings(
        image_size=(width, height),
        radius=vert_rad,
    )
    rasterizer = PointsRasterizer(cameras=cameras,
                                  raster_settings=raster_settings)
    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
    # Render.
    image = renderer(
        pcl,
        gamma=(1.0e-1, ),  # Renderer blending parameter gamma, in [1., 1e-5].
        znear=(1.0, ),
        zfar=(45.0, ),
        radius_world=True,
        bg_col=torch.ones((3, ), dtype=torch.float32, device=device),
    )[0]
    LOGGER.info("Writing image to `%s`.", path.abspath("basic-pt3d.png"))
    imageio.imsave("basic-pt3d.png",
                   (image.cpu().detach() * 255.0).to(torch.uint8).numpy())
    LOGGER.info("Done.")
Example No. 6
    def test_raysampler_caching(self, batch_size=10):
        """
        Tests the consistency of the NeRF raysampler caching.
        """

        raysampler = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=10,
            min_depth=0.1,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        raysampler.eval()

        cameras, rays = [], []

        for _ in range(batch_size):

            R = random_rotations(1)
            T = torch.randn(1, 3)
            focal_length = torch.rand(1, 2) + 0.5
            principal_point = torch.randn(1, 2)

            camera = PerspectiveCameras(
                focal_length=focal_length,
                principal_point=principal_point,
                R=R,
                T=T,
            )

            cameras.append(camera)
            rays.append(raysampler(camera))

        raysampler.precache_rays(cameras, list(range(batch_size)))

        for cam_index, rays_ in enumerate(rays):
            rays_cached_ = raysampler(
                cameras=cameras[cam_index],
                chunksize=None,
                chunk_idx=0,
                camera_hash=cam_index,
                caching=False,
            )

            for v, v_cached in zip(rays_, rays_cached_):
                self.assertTrue(torch.allclose(v, v_cached))
Example No. 7
    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Check that the probabilistic ray sampler does not crash for various
        settings.
        """

        raysampler_grid = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=1.0,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        R = random_rotations(batch_size)
        T = torch.randn(batch_size, 3)
        focal_length = torch.rand(batch_size, 2) + 0.5
        principal_point = torch.randn(batch_size, 2)
        camera = PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

        raysampler_grid.eval()

        ray_bundle = raysampler_grid(cameras=camera)

        ray_weights = torch.rand_like(ray_bundle.lengths)

        # Just check that we don't crash for all possible settings.
        for stratified_test in (True, False):
            for stratified in (True, False):
                raysampler_prob = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    add_input_samples=True,
                )
                for mode in ("train", "eval"):
                    getattr(raysampler_prob, mode)()
                    for _ in range(10):
                        raysampler_prob(ray_bundle, ray_weights)
Example No. 8
            def on_change(value):
                img_copy = image.copy()

                x = (cv2.getTrackbarPos("x", "image") - 1000) / 1000
                y = (cv2.getTrackbarPos("y", "image") - 1000) / 1000
                z = cv2.getTrackbarPos("z", "image") / 1000
                rx = cv2.getTrackbarPos("rx", "image")
                ry = cv2.getTrackbarPos("ry", "image")
                rz = cv2.getTrackbarPos("rz", "image")

                T = torch.tensor([[x, y, z]],
                                 dtype=torch.float32,
                                 device=device)
                R = Rotation.from_euler("zyx", [rz, ry, rx],
                                        degrees=True).as_matrix()

                renderR = torch.from_numpy(R.T.reshape((1, 3, 3))).to(device)

                cameras = PerspectiveCameras(
                    R=renderR,
                    T=T,
                    focal_length=-self.f,
                    principal_point=self.p,
                    image_size=(self.img_size, ),
                    device=device,
                )

                raster_settings = RasterizationSettings(
                    image_size=(self.intrinsics.height, self.intrinsics.width),
                    blur_radius=0.0,
                    faces_per_pixel=1,
                )
                renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(cameras=cameras,
                                              raster_settings=raster_settings),
                    shader=SoftPhongShader(
                        device=device,
                        cameras=cameras,
                    ),
                )
                overlay = renderer(mesh)[0, ..., :3].cpu().numpy()[:, :, ::-1]
                render_img = overlay * 0.7 + img_copy / 255 * 0.3
                cv2.imshow(windowName, render_img)

                store_and_exit = cv2.getTrackbarPos(
                    "0 : Manual Match \n1 : Store and Exit", "image")
                if store_and_exit:
                    cv2.destroyAllWindows()
                    pose[mesh_name] = {
                        "translation": T.cpu().numpy(),
                        "rotation": R
                    }
Example No. 9
    def __init__(self):
        self.H, self.W = 249, 125
        self.image_size = (self.H, self.W)
        self.camera_ndc = PerspectiveCameras(
            focal_length=1.0,
            image_size=(self.image_size, ),
            in_ndc=True,
            T=torch.tensor([[0.0, 0.0, 0.0], [-1.0, self.H / self.W, 0.0]]),
            principal_point=((-0.0, -0.0), (1.0, -self.H / self.W)),
        )
        # Note how the principal point is specified
        self.camera_screen = PerspectiveCameras(
            focal_length=self.W / 2.0,
            principal_point=((self.W / 2.0, self.H / 2.0), (0.0, self.H)),
            image_size=(self.image_size, ),
            T=torch.tensor([[0.0, 0.0, 0.0], [-1.0, self.H / self.W, 0.0]]),
            in_ndc=False,
        )

        # 81 is more than half of 125, 113 is a bit less than half of 249
        self.x, self.y = 81, 113
        self.point = [-0.304, 0.176, 1]
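
A hedged consistency check, not present in the original: assuming a PyTorch3D version (>= 0.6) in which `transform_points_screen` reads the `image_size` stored on the cameras, both parameterizations above should send the same world point to the same pixel location. `check_equivalence` is a hypothetical helper continuing the same class.

    def check_equivalence(self):
        # project the test point with both camera batches; outputs have shape (2, 1, 3)
        pts = torch.tensor([self.point], dtype=torch.float32)
        xy_ndc = self.camera_ndc.transform_points_screen(pts)[..., :2]
        xy_screen = self.camera_screen.transform_points_screen(pts)[..., :2]
        # the NDC-space and screen-space parameterizations describe the same cameras
        assert torch.allclose(xy_ndc, xy_screen, atol=1e-4)
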
Example No. 10
def rasterize_mc_samples(
    xys: torch.Tensor,
    feats: torch.Tensor,
    image_size_hw: Tuple[int, int],
    radius: float = 0.03,
    topk: int = 5,
    masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rasterizes Monte-Carlo sampled features back onto the image.

    Specifically, the code uses the PyTorch3D point rasterizer to render
    a z-flat point cloud composed of the xy MC locations and their features.

    Args:
        xys: B x N x 2 tensor of 2D point locations in the PyTorch3D NDC convention.
        feats: B x N x dim tensor containing per-point rendered features.
        image_size_hw: Tuple[image_height, image_width] containing
            the size of the rasterized image.
        radius: Rasterization point radius.
        topk: The maximum z-buffer size for the PyTorch3D point cloud rasterizer.
        masks: B x N x 1 tensor containing the alpha mask of the
            rendered features.
    """

    if masks is None:
        masks = torch.ones_like(xys[..., :1])

    feats = torch.cat((feats, masks), dim=-1)
    pointclouds = Pointclouds(
        points=torch.cat([xys, torch.ones_like(xys[..., :1])], dim=-1),
        features=feats,
    )

    data_rendered, render_mask, _ = render_point_cloud_pytorch3d(
        PerspectiveCameras(device=feats.device),
        pointclouds,
        render_size=image_size_hw,
        point_radius=radius,
        topk=topk,
    )

    data_rendered, masks_pt = data_rendered.split(
        [data_rendered.shape[1] - 1, 1], dim=1)
    render_mask = masks_pt * render_mask

    return data_rendered, render_mask
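
A hedged call sketch with made-up shapes (not from the original; it assumes `render_point_cloud_pytorch3d` is importable, e.g. from PyTorch3D's implicitron tools):

import torch

# two batches of 1024 Monte-Carlo samples with 3-dimensional features (hypothetical sizes)
xys = torch.rand(2, 1024, 2) * 2.0 - 1.0   # NDC locations in [-1, 1]
feats = torch.rand(2, 1024, 3)
images, masks = rasterize_mc_samples(xys, feats, image_size_hw=(128, 128))
# images: the rendered features split off from the alpha channel; masks: the render mask
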
Example No. 11
def render_img(face_shape,
               face_color,
               facemodel,
               image_size=224,
               fx=1015.0,
               fy=1015.0,
               px=112.0,
               py=112.0,
               device='cuda:0'):
    '''
        ref: https://github.com/facebookresearch/pytorch3d/issues/184
        The rendering function (just for test)
        Input:
            face_shape:  Tensor[1, 35709, 3]
            face_color: Tensor[1, 35709, 3] in [0, 1]
            facemodel: contains `tri` (triangles [70789, 3]; indices start from 1)
    '''
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.mesh.textures import TexturesVertex
    from pytorch3d.renderer import (PerspectiveCameras, PointLights,
                                    RasterizationSettings, MeshRenderer,
                                    MeshRasterizer, SoftPhongShader,
                                    BlendParams)

    face_color = TexturesVertex(verts_features=face_color.to(device))
    face_buf = torch.from_numpy(facemodel.tri - 1)  # index start from 1
    face_idx = face_buf.unsqueeze(0)

    mesh = Meshes(face_shape.to(device), face_idx.to(device), face_color)

    R = torch.eye(3).view(1, 3, 3).to(device)
    R[0, 0, 0] *= -1.0
    T = torch.zeros([1, 3]).to(device)

    half_size = (image_size - 1.0) / 2
    focal_length = torch.tensor([fx / half_size, fy / half_size],
                                dtype=torch.float32).reshape(1, 2).to(device)
    principal_point = torch.tensor([(half_size - px) / half_size,
                                    (py - half_size) / half_size],
                                   dtype=torch.float32).reshape(1,
                                                                2).to(device)

    cameras = PerspectiveCameras(device=device,
                                 R=R,
                                 T=T,
                                 focal_length=focal_length,
                                 principal_point=principal_point)

    raster_settings = RasterizationSettings(image_size=image_size,
                                            blur_radius=0.0,
                                            faces_per_pixel=1)

    lights = PointLights(device=device,
                         ambient_color=((1.0, 1.0, 1.0), ),
                         diffuse_color=((0.0, 0.0, 0.0), ),
                         specular_color=((0.0, 0.0, 0.0), ),
                         location=((0.0, 0.0, 1e5), ))

    blend_params = BlendParams(background_color=(0.0, 0.0, 0.0))

    renderer = MeshRenderer(rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=raster_settings),
                            shader=SoftPhongShader(device=device,
                                                   cameras=cameras,
                                                   lights=lights,
                                                   blend_params=blend_params))
    images = renderer(mesh)
    images = torch.clamp(images, 0.0, 1.0)
    return images
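
The intrinsics conversion above maps pixel-space focal lengths and principal points into PyTorch3D's NDC convention. A small worked example with the default arguments (plain arithmetic, restating the formulas in the function):

image_size = 224
half_size = (image_size - 1.0) / 2          # 111.5
fx_ndc = 1015.0 / half_size                 # ~ 9.10
px_ndc = (half_size - 112.0) / half_size    # ~ -0.0045
py_ndc = (112.0 - half_size) / half_size    # ~ 0.0045
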
Example No. 12
def cameras_from_opencv_projection2(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
) -> PerspectiveCameras:
    """
    Converts a batch of OpenCV-conventioned cameras parametrized with the
    rotation matrices `R`, translation vectors `tvec`, and the camera
    calibration matrices `camera_matrix` to `PerspectiveCameras` in PyTorch3D
    convention.
    More specifically, the conversion is carried out such that a projection
    of a 3D shape to the OpenCV-conventioned screen of size `image_size` results
    in the same image as a projection with the corresponding PyTorch3D camera
    to the NDC screen convention of PyTorch3D.
    In particular, the OpenCV convention projects points to the OpenCV screen
    space as follows:
        ```
        x_screen_opencv = camera_matrix @ (R @ x_world + tvec)
        ```
    followed by the homogenization of `x_screen_opencv`.
    Note:
        The parameters `R, tvec, camera_matrix` correspond to the outputs of
        `cv2.decomposeProjectionMatrix`.
        The `rvec` parameter of the `cv2.projectPoints` is an axis-angle vector
        that can be converted to the rotation matrix `R` expected here by
        calling the `so3_exp_map` function.
    Args:
        R: A batch of rotation matrices of shape `(N, 3, 3)`.
        tvec: A batch of translation vectors of shape `(N, 3)`.
        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
            (height, width) attached to each camera.
    Returns:
        cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
    """
    focal_length = torch.stack(
        [camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    principal_point = camera_matrix[:, :2, 2]

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1, ))

    use_ndc = False

    if use_ndc:
        # Get the PyTorch3D focal length and principal point.
        focal_pytorch3d = focal_length / (0.5 * image_size_wh)
        p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
    else:
        focal_pytorch3d = focal_length
        p0_pytorch3d = principal_point

    # For R, T we flip the x and y axes (the OpenCV screen space has an opposite
    # orientation of the screen axes).
    # We also transpose R (OpenCV multiplies points from the left side).
    R_pytorch3d = R.clone().permute(0, 2, 1)
    T_pytorch3d = tvec.clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    return PerspectiveCameras(R=R_pytorch3d,
                              T=T_pytorch3d,
                              focal_length=focal_pytorch3d,
                              principal_point=p0_pytorch3d,
                              image_size=image_size,
                              in_ndc=use_ndc)
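
A hedged usage sketch with made-up intrinsics (not from the original): a single 640x480 pinhole camera with a 500 px focal length and the principal point at the image center.

import torch

K = torch.tensor([[[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]]])      # (1, 3, 3) calibration matrix
R = torch.eye(3)[None]                     # (1, 3, 3) identity rotation
tvec = torch.zeros(1, 3)                   # (1, 3) zero translation
image_size = torch.tensor([[480, 640]])    # (1, 2) as (height, width)
cams = cameras_from_opencv_projection2(R, tvec, K, image_size)
# with `use_ndc = False` hard-coded above, the camera keeps pixel-space intrinsics
print(cams.focal_length, cams.principal_point)
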
Example No. 13
        ).shift((30, 0))
        cropped_img = crop_box.apply(img)
        return cropped_img


x3d, xface = load_off(mesh_path)

faces = torch.from_numpy(xface)

# TODO: convert verts
verts = torch.from_numpy(x3d)
verts = pre_process_mesh_pascal(verts)
# cameras = OpenGLPerspectiveCameras(device=device, fov=12.0)
cameras = PerspectiveCameras(focal_length=1.0 * 3000,
                             principal_point=((render_image_size[0] / 2,
                                               render_image_size[1] / 2), ),
                             image_size=(render_image_size, ),
                             device=device)

verts_rgb = torch.ones_like(verts)[None] * torch.Tensor([1, 0.85, 0.85]).view(
    1, 1, 3)  # (1, V, 3)
# textures = Textures(verts_rgb=verts_rgb.to(device))
textures = Textures(verts_features=verts_rgb.to(device))
meshes = Meshes(verts=[verts], faces=[faces], textures=textures)
meshes = meshes.to(device)

blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
raster_settings = RasterizationSettings(image_size=render_image_size,
                                        blur_radius=0.0,
                                        faces_per_pixel=1,
                                        bin_size=0)
Example No. 14
    def run_optimization(
        self,
        silhouettes: torch.tensor,
        R: torch.tensor,
        T: torch.tensor,
        writer=None,
        camera_settings=None,
        step: int = 0,
    ):
        """
        Function:
            Runs a batched optimization procedure that aims to minimize 3 reconstruction losses:
                -Silhouette IoU Loss: between input silhouettes and re-projected mesh
                -Mesh Edge consistency
                -Mesh Normal smoothing
            Mini Batching:
                If the number of silhouettes is greater than the allowed batch size, then a random set of images/poses is sampled for supervision at each step.
        Returns:
            -Reconstruction losses: 3 reconstruction losses measured during optimization
            -Timing:
                -Iterations / second
                -Total time elapsed in seconds
        """

        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        tf_smaller = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.params.img_size),
            transforms.ToTensor(),
        ])

        images_gt = torch.stack([
            tf_smaller(s.cpu()).to(self.device) for s in silhouettes
        ]).squeeze(1)

        if images_gt.max() > 1.0:
            images_gt = images_gt / 255.0

        loop = tqdm_notebook(range(self.params.mesh_steps))

        start_time = time.time()
        for i in loop:
            batch_indices = (random.choices(list(range(images_gt.shape[0])),
                                            k=self.params.mesh_batch_size) if
                             images_gt.shape[0] > self.params.mesh_batch_size
                             else list(range(images_gt.shape[0])))
            batch_silhouettes = images_gt[batch_indices]

            batch_R, batch_T = R[batch_indices], T[batch_indices]
            # apply right transform on the Twv to adjust the coordinate system shift from EVIMO to PyTorch3D
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = (torch.tensor([
                    camera_settings[0, 0], camera_settings[1, 1]
                ])[None]).expand(batch_R.shape[0], 2)
                principle_point = (torch.tensor([
                    camera_settings[0, 2], camera_settings[1, 2]
                ])[None]).expand(batch_R.shape[0], 2)
                # FIXME: in this PyTorch3D version, the image_size in RasterizationSettings is (W, H), while in PerspectiveCameras it is (H, W).
                # If a future PyTorch3D version changes the format, please update the settings here.
                # We hope PyTorch3D will resolve this inconsistency in the future.
                batch_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principle_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                batch_cameras = PerspectiveCameras(device=self.device,
                                                   R=batch_R,
                                                   T=batch_T)

            mesh, laplacian_loss, flatten_loss = self.forward(
                self.params.mesh_batch_size)

            images_pred = self.renderer(mesh,
                                        device=self.device,
                                        cameras=batch_cameras)[..., -1]

            iou_loss = IOULoss().forward(batch_silhouettes, images_pred)

            loss = (iou_loss * self.params.lambda_iou +
                    laplacian_loss * self.params.lambda_laplacian +
                    flatten_loss * self.params.lambda_flatten)

            loop.set_description("Optimizing (loss %.4f)" % loss.data)

            self.losses["iou"].append(iou_loss * self.params.lambda_iou)
            self.losses["laplacian"].append(laplacian_loss *
                                            self.params.lambda_laplacian)
            self.losses["flatten"].append(flatten_loss *
                                          self.params.lambda_flatten)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if i % (self.params.mesh_show_step /
                    2) == 0 and self.params.mesh_log:
                logging.info(
                    f'Iteration: {i} IOU Loss: {iou_loss.item()} Flatten Loss: {flatten_loss.item()} Laplacian Loss: {laplacian_loss.item()}'
                )

            if i % self.params.mesh_show_step == 0 and self.params.im_show:
                # Write images
                image = images_pred.detach().cpu().numpy()[0]

                if writer:
                    writer.append_data((255 * image).astype(np.uint8))
                plt.imshow(images_pred.detach().cpu().numpy()[0])
                plt.show()
                plt.imshow(batch_silhouettes.detach().cpu().numpy()[0])
                plt.show()
                plot_pointcloud(mesh[0], 'Mesh deformed')
                logging.info(
                    f'Pose of init camera: {self.init_camera_R.detach().cpu().numpy()}, {self.init_camera_t.detach().cpu().numpy()}'
                )

        # Set the final optimized mesh as an internal variable
        self.final_mesh = mesh[0].clone()
        results = dict(
            silhouette_loss=self.losses["iou"]
            [-1].detach().cpu().numpy().tolist(),
            laplacian_loss=self.losses["laplacian"]
            [-1].detach().cpu().numpy().tolist(),
            flatten_loss=self.losses["flatten"]
            [-1].detach().cpu().numpy().tolist(),
            iterations_per_second=self.params.mesh_steps /
            (time.time() - start_time),
            total_time_s=time.time() - start_time,
        )
        if self.is_real_data:
            self.init_pose_R = self.init_camera_R.detach().cpu().numpy()
            self.init_pose_t = self.init_camera_t.detach().cpu().numpy()

        torch.cuda.empty_cache()

        return results
Example No. 15
    def render_final_mesh(self,
                          poses,
                          mode: str,
                          out_size: list,
                          camera_settings=None) -> dict:
        """Renders the final mesh obtained through optimization
            Supports two modes:
                -predict: renders both silhouettes and flat shaded images
                -train: only renders silhouettes
            Returns:
                -dict of renders {'silhouettes': tensor, 'images': tensor}
        """
        R, T = poses
        if len(R.shape) == 4:
            R = R.squeeze(1)
            T = T.squeeze(1)

        sil_renderer = silhouette_renderer(out_size, self.device)
        image_renderer = flat_renderer(out_size, self.device)

        # Create a silhouette projection of the mesh across all views
        all_silhouettes = []
        all_images = []
        for i in range(0, R.shape[0]):
            batch_R, batch_T = R[[i]], T[[i]]
            if self.params.is_real_data:
                init_R = quaternion_to_matrix(self.init_camera_R)
                batch_R = _broadcast_bmm(batch_R, init_R)
                batch_T = (
                    _broadcast_bmm(batch_T[:, None, :], init_R) +
                    self.init_camera_t.expand(batch_R.shape[0], 1, 3))[:, 0, :]
                focal_length = torch.tensor(
                    [camera_settings[0, 0], camera_settings[1, 1]])[None]
                principle_point = torch.tensor(
                    [camera_settings[0, 2], camera_settings[1, 2]])[None]
                t_cameras = PerspectiveCameras(
                    device=self.device,
                    R=batch_R,
                    T=batch_T,
                    focal_length=focal_length,
                    principal_point=principle_point,
                    image_size=((self.params.img_size[1],
                                 self.params.img_size[0]), ))
            else:
                t_cameras = PerspectiveCameras(device=self.device,
                                               R=batch_R,
                                               T=batch_T)
            all_silhouettes.append(
                sil_renderer(self._final_mesh,
                             device=self.device,
                             cameras=t_cameras).detach().cpu()[..., -1])

            if mode == "predict":
                all_images.append(
                    torch.clamp(
                        image_renderer(self._final_mesh,
                                       device=self.device,
                                       cameras=t_cameras),
                        0,
                        1,
                    ).detach().cpu()[..., :3])
            torch.cuda.empty_cache()
        renders = dict(
            silhouettes=torch.cat(all_silhouettes).unsqueeze(-1).permute(
                0, 3, 1, 2),
            images=torch.cat(all_images) if all_images else [],
        )

        return renders
Example No. 16
    def test_ndc_grid_sample_rendering(self):
        """
        Use PyTorch3D point renderer to render a colored point cloud, then
        sample the image at the locations of the point projections with
        `ndc_grid_sample`. Finally, assert that the sampled colors are equal to the
        original point cloud colors.

        Note that, in order to ensure correctness, we use a nearest-neighbor
        assignment point renderer (i.e. no soft splatting).
        """

        # generate a bunch of 3D points on a regular grid lying in the z-plane
        n_grid_pts = 10
        grid_scale = 0.9
        z_plane = 2.0
        image_size = [128, 128]
        point_radius = 0.015
        n_pts = n_grid_pts * n_grid_pts
        pts = torch.stack(
            meshgrid_ij([torch.linspace(-grid_scale, grid_scale, n_grid_pts)] *
                        2, ),
            dim=-1,
        )
        pts = torch.cat([pts, z_plane * torch.ones_like(pts[..., :1])], dim=-1)
        pts = pts.reshape(1, n_pts, 3)

        # color the points randomly
        pts_colors = torch.rand(1, n_pts, 3)

        # make trivial rendering cameras
        cameras = PerspectiveCameras(
            R=eyes(dim=3, N=1),
            device=pts.device,
            T=torch.zeros(1, 3, dtype=torch.float32, device=pts.device),
        )

        # render the point cloud
        pcl = Pointclouds(points=pts, features=pts_colors)
        renderer = NearestNeighborPointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras,
                raster_settings=PointsRasterizationSettings(
                    image_size=image_size,
                    radius=point_radius,
                    points_per_pixel=1,
                ),
            ),
            compositor=AlphaCompositor(),
        )
        im_render = renderer(pcl)

        # sample the render at projected pts
        pts_proj = cameras.transform_points(pcl.points_padded())[..., :2]
        pts_colors_sampled = ndc_grid_sample(
            im_render,
            pts_proj,
            mode="nearest",
            align_corners=False,
        ).permute(0, 2, 1)

        # assert that the samples are the same as original points
        self.assertClose(pts_colors, pts_colors_sampled, atol=1e-4)
Example No. 17
def generate_eval_video_cameras(
    train_cameras,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
    focal_length: Optional[torch.FloatTensor] = None,
    principal_point: Optional[torch.FloatTensor] = None,
    time: Optional[torch.FloatTensor] = None,
    infer_up_as_plane_normal: bool = True,
    traj_offset: Optional[Tuple[float, float, float]] = None,
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
    """
    Generate a camera trajectory rendering a scene from multiple viewpoints.

    Args:
        train_cameras: The cameras of the training dataset.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for `trajectory_type="circular_lsq_fit"`.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
    Returns:
        A batch of `PerspectiveCameras` which can be used as the test dataset.
    """
    if trajectory_type in ("figure_eight", "trefoil_knot",
                           "figure_eight_knot"):
        cam_centers = train_cameras.get_camera_center()
        # get the nearest camera center to the mean of centers
        mean_camera_idx = (((cam_centers -
                             cam_centers.mean(dim=0)[None])**2).sum(dim=1).min(
                                 dim=0).indices)
        # generate the knot trajectory in canonical coords
        if time is None:
            time = torch.linspace(0, 2 * math.pi,
                                  n_eval_cams + 1)[:n_eval_cams]
        else:
            assert time.numel() == n_eval_cams
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        else:
            raise ValueError(f"bad trajectory type: {trajectory_type}")
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        mean_camera = PerspectiveCameras(
            **{
                k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
                for k in ("focal_length", "principal_point", "R", "T")
            })
        traj_trans = Scale(
            cam_centers.std(dim=0).mean() * trajectory_scale).compose(
                mean_camera.get_world_to_view_transform().inverse())

        if traj_offset_canonical is not None:
            traj_trans = traj_trans.translate(
                torch.FloatTensor(traj_offset_canonical)[None].to(traj))

        traj = traj_trans.transform_points(traj)

        plane_normal = _fit_plane(cam_centers)[:, 0]
        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    elif trajectory_type == "circular_lsq_fit":
        ### fit plane to the camera centers

        # get the center of the plane as the median of the camera centers
        cam_centers = train_cameras.get_camera_center()

        if time is not None:
            angle = time
        else:
            angle = torch.linspace(0, 2.0 * math.pi,
                                   n_eval_cams).to(cam_centers)

        fit = fit_circle_in_3d(
            cam_centers,
            angles=angle,
            offset=angle.new_tensor(traj_offset_canonical)
            if traj_offset_canonical is not None else None,
            up=angle.new_tensor(up),
        )
        traj = fit.generated_points

        # scale the trajectory
        _t_mu = traj.mean(dim=0, keepdim=True)
        traj = (traj - _t_mu) * trajectory_scale + _t_mu

        plane_normal = fit.normal

        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    else:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")

    if traj_offset is not None:
        traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center, ),  # (1, 3)
        up=(up, ),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    if focal_length is None:
        focal_length = train_cameras.focal_length.mean(dim=0).repeat(
            n_eval_cams, 1)
    if principal_point is None:
        principal_point = train_cameras.principal_point.mean(dim=0).repeat(
            n_eval_cams, 1)

    test_cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=focal_length.device,
    )

    # _visdom_plot_scene(
    #     train_cameras,
    #     test_cameras,
    # )

    return test_cameras
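
A hedged usage sketch (not from the original): it builds a toy set of training cameras and requests a circular evaluation trajectory. It assumes `generate_eval_video_cameras` and its dependencies (`fit_circle_in_3d`, `look_at_view_transform`) are importable; `infer_up_as_plane_normal=False` keeps the sketch on the code path shown above.

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms import random_rotations

# toy "training" cameras scattered around the scene
train_cams = PerspectiveCameras(R=random_rotations(8), T=torch.randn(8, 3))
eval_cams = generate_eval_video_cameras(
    train_cams,
    n_eval_cams=20,
    trajectory_type="circular_lsq_fit",
    infer_up_as_plane_normal=False,
)
print(eval_cams.R.shape)  # torch.Size([20, 3, 3])
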
Example No. 18
def main_function(experiment_directory, continue_from, iterations,
                  marching_cubes_resolution, regularize):

    device = torch.device('cuda:0')
    specs = ws.load_experiment_specifications(experiment_directory)

    print("Reconstruction from experiment description: \n" +
          ' '.join([str(elem) for elem in specs["Description"]]))

    data_source = specs["DataSource"]
    test_split_file = specs["TestSplit"]

    arch_encoder = __import__("lib.models." + specs["NetworkEncoder"],
                              fromlist=["ResNet"])
    arch_decoder = __import__("lib.models." + specs["NetworkDecoder"],
                              fromlist=["DeepSDF"])
    latent_size = specs["CodeLength"]

    encoder = arch_encoder.ResNet(latent_size,
                                  specs["Depth"],
                                  norm_type=specs["NormType"]).cuda()
    decoder = arch_decoder.DeepSDF(latent_size, **specs["NetworkSpecs"]).cuda()

    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)

    print("testing with {} GPU(s)".format(torch.cuda.device_count()))

    num_samp_per_scene = specs["SamplesPerScene"]
    with open(test_split_file, "r") as f:
        test_split = json.load(f)

    sdf_dataset_test = lib.data.RGBA2SDF(data_source,
                                         test_split,
                                         num_samp_per_scene,
                                         is_train=False,
                                         num_views=specs["NumberOfViews"])
    torch.manual_seed(int(time.time() * 1000.0))
    sdf_loader_test = data_utils.DataLoader(
        sdf_dataset_test,
        batch_size=1,
        shuffle=True,
        num_workers=1,
        drop_last=False,
    )

    num_scenes = len(sdf_loader_test)
    print("There are {} scenes".format(num_scenes))

    print('Loading epoch "{}"'.format(continue_from))

    ws.load_model_parameters(experiment_directory, continue_from, encoder,
                             decoder)
    encoder.eval()

    optimization_meshes_dir = os.path.join(experiment_directory,
                                           ws.reconstructions_subdir,
                                           str(continue_from))

    if not os.path.isdir(optimization_meshes_dir):
        os.makedirs(optimization_meshes_dir)

    for sdf_data, image, intrinsic, extrinsic, name in sdf_loader_test:

        out_name = name[0].split("/")[-1]
        # store input stuff
        image_filename = os.path.join(optimization_meshes_dir, out_name,
                                      "input.png")
        # skip if it is already there
        if os.path.exists(os.path.dirname(image_filename)):
            print(name[0], " exists already ")
            continue
        print('Reconstructing {}...'.format(out_name))

        if not os.path.exists(os.path.dirname(image_filename)):
            os.makedirs(os.path.dirname(image_filename))

        image_export = 255 * image[0].permute(1, 2, 0).cpu().numpy()
        imageio.imwrite(image_filename, image_export.astype(np.uint8))

        image_filename = os.path.join(optimization_meshes_dir, out_name,
                                      "input_silhouette.png")
        image_export = 255 * image[0].permute(1, 2, 0).cpu().numpy()[..., 3]
        imageio.imwrite(image_filename, image_export.astype(np.uint8))

        # get latent code from image
        latent = encoder(image)
        # get estimated mesh
        verts, faces, samples, next_indices = lib.mesh.create_mesh(
            decoder, latent, N=marching_cubes_resolution, output_mesh=True)

        # store raw output
        mesh_filename = os.path.join(optimization_meshes_dir, out_name,
                                     "predicted.ply")
        lib.mesh.write_verts_faces_to_file(verts, faces, mesh_filename)

        verts_dr = torch.tensor(verts[None, :, :].copy(),
                                dtype=torch.float32,
                                requires_grad=False).cuda()
        faces_dr = torch.tensor(faces[None, :, :].copy()).cuda()

        IMG_SIZE = image.shape[-1]
        K_cuda = torch.tensor(intrinsic[:, 0:3, 0:3]).float().cuda()
        R_cuda = torch.tensor(extrinsic[:, 0:3,
                                        0:3]).float().cuda().permute(0, 2, 1)
        t_cuda = torch.tensor(extrinsic[:, 0:3, 3]).float().cuda()
        lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
        cameras = PerspectiveCameras(device=device,
                                     focal_length=-K_cuda[:, 0, 0] /
                                     K_cuda[:, 0, 2],
                                     image_size=((IMG_SIZE, IMG_SIZE), ),
                                     R=R_cuda,
                                     T=t_cuda)
        raster_settings = RasterizationSettings(
            image_size=IMG_SIZE,
            blur_radius=0.000001,
            faces_per_pixel=1,
        )
        raster_settings_soft = RasterizationSettings(
            image_size=IMG_SIZE,
            blur_radius=np.log(1. / 1e-4 - 1.) * 1e-5,
            faces_per_pixel=25,
        )

        # instantiate renderers
        silhouette_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings_soft),
                                           shader=SoftSilhouetteShader())
        depth_renderer = MeshRasterizer(cameras=cameras,
                                        raster_settings=raster_settings)
        renderer = Renderer(silhouette_renderer,
                            depth_renderer,
                            image_size=IMG_SIZE)

        meshes = Meshes(verts_dr, faces_dr)
        verts_shape = meshes.verts_packed().shape
        verts_rgb = torch.full([1, verts_shape[0], 3],
                               0.5,
                               device=device,
                               requires_grad=False)
        meshes.textures = TexturesVertex(verts_features=verts_rgb)

        with torch.no_grad():

            normal_out, silhouette_out = renderer(meshes_world=meshes,
                                                  cameras=cameras,
                                                  lights=lights)

            image_out_export = 255 * silhouette_out.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name,
                                              "predicted_silhouette.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

            image_out_export = 255 * normal_out.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name, "predicted.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

        # load ground truth mesh for metrics
        mesh_filename = os.path.join(
            data_source, name[0].replace("samples", "meshes") + ".obj")

        mesh = trimesh.load(mesh_filename)
        vertices = torch.tensor(mesh.vertices).float().cuda()
        faces = torch.tensor(mesh.faces).float().cuda()

        vertices = vertices.unsqueeze(0)
        faces = faces.unsqueeze(0)
        meshes_gt = Meshes(vertices, faces)

        with torch.no_grad():
            normal_tgt, _ = renderer(meshes_world=meshes_gt,
                                     cameras=cameras,
                                     lights=lights)

        latent_for_optim = torch.tensor(latent, requires_grad=True)
        lr = 5e-5
        optimizer = torch.optim.Adam([latent_for_optim], lr=lr)

        decoder.eval()

        log_silhouette = []
        log_latent = []
        log_chd = []
        log_nc = []

        for e in range(iterations + 1):

            optimizer.zero_grad()

            # first create mesh
            verts, faces, samples, next_indices = lib.mesh.create_mesh_optim_fast(
                samples,
                next_indices,
                decoder,
                latent_for_optim,
                N=marching_cubes_resolution)

            # now assemble loss function
            xyz_upstream = torch.tensor(verts.astype(float),
                                        requires_grad=True,
                                        dtype=torch.float32,
                                        device=device)
            faces_upstream = torch.tensor(faces.astype(float),
                                          requires_grad=False,
                                          dtype=torch.float32,
                                          device=device)

            meshes_dr = Meshes(xyz_upstream.unsqueeze(0),
                               faces_upstream.unsqueeze(0))
            verts_shape = meshes_dr.verts_packed().shape
            verts_rgb = torch.full([1, verts_shape[0], 3],
                                   0.5,
                                   device=device,
                                   requires_grad=False)
            meshes_dr.textures = TexturesVertex(verts_features=verts_rgb)

            normal, silhouette = renderer(meshes_world=meshes_dr,
                                          cameras=cameras,
                                          lights=lights)
            # compute loss
            loss_silhouette = (torch.abs(silhouette -
                                         image[:, 3].cuda())).mean()

            # now store upstream gradients
            loss_silhouette.backward()
            dL_dx_i = xyz_upstream.grad
            # take care of weird stuff possibly happening
            dL_dx_i[torch.isnan(dL_dx_i)] = 0

            # log stuff
            with torch.no_grad():
                log_silhouette.append(loss_silhouette.detach().cpu().numpy())

                meshes_gt_pts = sample_points_from_meshes(meshes_gt)
                meshes_dr_pts = sample_points_from_meshes(meshes_dr)
                metric_chd, _ = chamfer_distance(meshes_gt_pts, meshes_dr_pts)
                log_chd.append(metric_chd.detach().cpu().numpy())

                log_nc.append(compute_normal_consistency(normal_tgt, normal))

                log_latent.append(
                    torch.mean(
                        (latent_for_optim).pow(2)).detach().cpu().numpy())

            # use vertices to compute full backward pass
            optimizer.zero_grad()
            xyz = torch.tensor(verts.astype(float),
                               requires_grad=True,
                               dtype=torch.float32,
                               device=torch.device('cuda:0'))
            latent_inputs = latent_for_optim.expand(xyz.shape[0], -1)
            # first compute normals
            pred_sdf = decoder(latent_inputs, xyz)
            loss_normals = torch.sum(pred_sdf)
            loss_normals.backward(retain_graph=True)
            normals = xyz.grad / torch.norm(xyz.grad, 2, 1).unsqueeze(-1)
            # now assemble inflow derivative
            optimizer.zero_grad()
            dL_ds_i = -torch.matmul(dL_dx_i.unsqueeze(1),
                                    normals.unsqueeze(-1)).squeeze(-1)
            # finally assemble full backward pass
            loss_backward = torch.sum(
                dL_ds_i * pred_sdf) + regularize * torch.mean(
                    (latent_for_optim).pow(2))
            loss_backward.backward()
            # and update params
            optimizer.step()

        # store all
        with torch.no_grad():
            verts, faces, samples, next_indices = lib.mesh.create_mesh_optim_fast(
                samples,
                next_indices,
                decoder,
                latent_for_optim,
                N=marching_cubes_resolution)
            mesh_filename = os.path.join(optimization_meshes_dir, out_name,
                                         "refined.ply")
            lib.mesh.write_verts_faces_to_file(verts, faces, mesh_filename)
            xyz_upstream = torch.tensor(verts.astype(float),
                                        requires_grad=True,
                                        dtype=torch.float32,
                                        device=device)
            faces_upstream = torch.tensor(faces.astype(float),
                                          requires_grad=False,
                                          dtype=torch.float32,
                                          device=device)

            meshes_dr = Meshes(xyz_upstream.unsqueeze(0),
                               faces_upstream.unsqueeze(0))
            verts_shape = meshes_dr.verts_packed().shape
            verts_rgb = torch.full([1, verts_shape[0], 3],
                                   0.5,
                                   device=device,
                                   requires_grad=False)
            meshes_dr.textures = TexturesVertex(verts_features=verts_rgb)

            normal, silhouette = renderer(meshes_world=meshes_dr,
                                          cameras=cameras,
                                          lights=lights)

            image_out_export = 255 * silhouette.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name,
                                              "refined_silhouette.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))
            image_out_export = 255 * normal.detach().cpu().numpy()[0]
            image_out_filename = os.path.join(optimization_meshes_dir,
                                              out_name, "refined.png")
            imageio.imwrite(image_out_filename,
                            image_out_export.astype(np.uint8))

        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_silhouette.npy")
        np.save(log_filename, log_silhouette)
        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_chd.npy")
        np.save(log_filename, log_chd)
        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_nc.npy")
        np.save(log_filename, log_nc)

        compute_normal_consistency(normal_tgt, normal)

        log_filename = os.path.join(optimization_meshes_dir, out_name,
                                    "log_latent.npy")
        np.save(log_filename, log_latent)
        print('Done with refinement.')
        print('Improvement in CHD {:.2f} %'.format(
            100 * (log_chd[0] - log_chd[-1]) / log_chd[0]))
        print('Improvement in NC {:.2f} %'.format(
            100 * (log_nc[-1] - log_nc[0]) / log_nc[0]))
Example No. 19
def batch_render(
    verts,
    faces,
    faces_per_pixel=10,
    K=None,
    rot=None,
    trans=None,
    colors=None,
    color=(0.53, 0.53, 0.8),  # light_purple
    ambient_col=0.5,
    specular_col=0.2,
    diffuse_col=0.3,
    face_colors=None,
    # color = (0.74117647, 0.85882353, 0.65098039),  # light_blue
    image_sizes=None,
    out_res=512,
    bin_size=0,
    shading="soft",
    mode="rgb",
    blend_gamma=1e-4,
    min_depth=None,
):
    device = torch.device("cuda:0")
    K = K.to(device)
    width, height = image_sizes[0]
    out_size = int(max(image_sizes[0]))
    raster_settings = RasterizationSettings(
        image_size=out_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        bin_size=bin_size,
    )

    fx = K[:, 0, 0]
    fy = K[:, 1, 1]
    focals = torch.stack([fx, fy], 1)
    px = K[:, 0, 2]
    py = K[:, 1, 2]
    principal_point = torch.stack([width - px, height - py], 1)
    if rot is None:
        rot = torch.eye(3).unsqueeze(0).to(device)
    if trans is None:
        trans = torch.zeros(3).unsqueeze(0).to(device)
    cameras = PerspectiveCameras(
        device=device,
        focal_length=focals,
        principal_point=principal_point,
        image_size=[(out_size, out_size) for _ in range(len(verts))],
        R=rot,
        T=trans,
    )
    if mode == "rgb":
        # single directional light; ambient/diffuse/specular intensities are configurable
        lights = DirectionalLights(
            device=device,
            direction=((0.6, -0.6, -0.6), ),
            ambient_color=((ambient_col, ambient_col, ambient_col), ),
            diffuse_color=((diffuse_col, diffuse_col, diffuse_col), ),
            specular_color=((specular_col, specular_col, specular_col), ),
        )
        if shading == "soft":
            shader = SoftPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        elif shading == "hard":
            shader = HardPhongShader(device=device,
                                     cameras=cameras,
                                     lights=lights)
        else:
            raise ValueError(
                f"Shading {shading} for mode rgb not in [soft|hard]")
    elif mode == "silh":
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        shader = SoftSilhouetteShader(blend_params=blend_params)
    elif shading == "faceidx":
        shader = FaceIdxShader()
    elif (mode == "facecolor") and (shading == "hard"):
        shader = FaceColorShader(face_colors=face_colors)
    elif (mode == "facecolor") and (shading == "soft"):
        shader = SoftFaceColorShader(face_colors=face_colors,
                                     blend_gamma=blend_gamma)
    else:
        raise ValueError(
            f"Unhandled mode {mode} and shading {shading} combination")

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras,
                                  raster_settings=raster_settings),
        shader=shader,
    )
    if min_depth is not None:
        verts = torch.cat([verts[:, :, :2], verts[:, :, 2:].clamp(min_depth)],
                          2)
    if mode == "rgb":
        if colors is None:
            colors = get_colors(verts, color)
        tex = textures.TexturesVertex(verts_features=colors)

        meshes = Meshes(verts=verts, faces=faces, textures=tex)
    elif mode in ["silh", "facecolor"]:
        meshes = Meshes(verts=verts, faces=faces)
    else:
        raise ValueError(f"Render mode {mode} not in [rgb|silh|facecolor]")

    square_images = renderer(meshes, cameras=cameras)
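    # flip the rendered image back to the row/column convention of the input
    # intrinsics (compensating for the mirrored principal point above)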
    square_images = torch.flip(square_images, (1, 2))
    height_off = abs(int(width - height))
    if width > height:
        images = square_images[:, height_off:, :]
    else:
        images = square_images[:, :, height_off:]
    return images
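
A minimal, hypothetical usage sketch for batch_render: the tetrahedron, pixel intrinsics, and image size below are made up for illustration, and vertex colors are passed explicitly so the external get_colors helper is not needed.

import torch

device = torch.device("cuda:0")
# a single small tetrahedron placed roughly 1 m in front of the camera
verts = torch.tensor([[[0.0, 0.0, 1.0],
                       [0.1, 0.0, 1.0],
                       [0.0, 0.1, 1.0],
                       [0.0, 0.0, 1.1]]], device=device)
faces = torch.tensor([[[0, 1, 2],
                       [0, 1, 3],
                       [0, 2, 3],
                       [1, 2, 3]]], device=device)
# illustrative pinhole intrinsics in pixels for a 512 x 512 image
K = torch.tensor([[[512.0, 0.0, 256.0],
                   [0.0, 512.0, 256.0],
                   [0.0, 0.0, 1.0]]], device=device)
colors = torch.ones_like(verts)  # plain white per-vertex colors
images = batch_render(verts, faces, K=K, colors=colors,
                      image_sizes=[(512, 512)], shading="soft", mode="rgb")
print(images.shape)  # roughly (1, 512, 512, 4): RGBA from the soft Phong shader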
Example No. 20
    def compute(self, points: torch.Tensor, sdf: torch.Tensor,
                mesh_gt: Meshes):
        """
        Rasterize mesh faces from an far camera facing the origin,
        transform the predicted points position to camera view and project to get the normalized image coordinates
        The number of points on the zbuf at the image coordinates that are larger than the predicted points
        determines the sign of sdf
        """
        assert (points.ndim == 2 and points.shape[-1] == 3)
        device = points.device
        faces_per_pixel = 4
        with torch.autograd.no_grad():
            # a point that is definitely outside the mesh as camera center
            ray0 = torch.tensor([2, 2, 2], device=device,
                                dtype=points.dtype).view(1, 3)
            R, T = look_at_view_transform(eye=ray0,
                                          at=((0, 0, 0), ),
                                          up=((0, 0, 1), ))
            cameras = PerspectiveCameras(R=R, T=T, device=device)
            rasterizer = MeshRasterizer(cameras=cameras,
                                        raster_settings=RasterizationSettings(
                                            faces_per_pixel=faces_per_pixel, ))
            fragments = rasterizer(mesh_gt)

            z_predicted = cameras.get_world_to_view_transform(
            ).transform_points(points=points.unsqueeze(0))[..., -1:]
            # normalized pixel (top-left smallest values)
            screen_xy = -cameras.transform_points(points.unsqueeze(0))[..., :2]
            outside_screen = (screen_xy.abs() > 1.0).any(dim=-1)

            # pix_to_face, zbuf, bary_coords, dists
            assert (fragments.zbuf.shape[-1] == faces_per_pixel)
            zbuf = torch.nn.functional.grid_sample(
                fragments.zbuf.permute(0, 3, 1, 2),
                screen_xy.clamp(-1.0, 1.0).view(1, -1, 1, 2),
                align_corners=False,
                mode='nearest')
            zbuf[outside_screen.unsqueeze(1).expand(-1, zbuf.shape[1],
                                                    -1)] = -1.0
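            # parity test: an even number of surface samples behind the point means it
            # lies outside the mesh (sign +1), an odd number means inside (sign -1)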
            sign = (((zbuf > z_predicted).sum(dim=1) %
                     2) == 0).type_as(points).view(screen_xy.shape[1])
            sign = sign * 2 - 1

        pcls = PointClouds3D(points.unsqueeze(0)).to(device=device)

        points_first_idx = pcls.cloud_to_packed_first_idx()
        max_points = pcls.num_points_per_cloud().max().item()

        # packed representation for faces
        verts_packed = mesh_gt.verts_packed()
        faces_packed = mesh_gt.faces_packed()
        tris = verts_packed[faces_packed]  # (T, 3, 3)
        tris_first_idx = mesh_gt.mesh_to_faces_packed_first_idx()
        max_tris = mesh_gt.num_faces_per_mesh().max().item()

        # point to face distance: shape (P,)
        point_to_face = point_face_distance(points, points_first_idx, tris,
                                            tris_first_idx, max_points)
        point_to_face = sign * torch.sqrt(eps_sqrt(point_to_face))
        loss = (point_to_face - sdf)**2
        return loss
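
A hypothetical call sketch for the signed-distance loss above. The wrapper class name SignedDistanceLoss is made up here (the real class name is not shown in this snippet); the ground-truth mesh is a unit ico-sphere and the target SDF is its analytic signed distance.

import torch
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
mesh_gt = ico_sphere(level=2, device=device)           # unit sphere as a stand-in GT mesh
points = torch.rand(1024, 3, device=device) * 2 - 1    # query points in [-1, 1]^3
sdf_target = points.norm(dim=-1) - 1.0                 # analytic SDF of the unit sphere

criterion = SignedDistanceLoss()  # hypothetical wrapper exposing the compute() above
loss = criterion.compute(points, sdf_target, mesh_gt)  # per-point squared SDF error, shape (P,)
print(loss.mean())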
Example No. 21
def get_nerf_datasets(
    dataset_name: str,  # 'lego | fern'
    image_size: Tuple[int, int],
    data_root: str = DEFAULT_DATA_ROOT,
    autodownload: bool = True,
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Obtains the training and validation dataset object for a dataset specified
    with the `dataset_name` argument.

    Args:
        dataset_name: The name of the dataset to load.
        image_size: A tuple (height, width) denoting the sizes of the loaded dataset images.
        data_root: The root folder at which the data is stored.
        autodownload: Auto-download the dataset files in case they are missing.

    Returns:
        train_dataset: The training dataset object.
        val_dataset: The validation dataset object.
        test_dataset: The testing dataset object.
    """

    if dataset_name not in ALL_DATASETS:
        raise ValueError(
            f"'{dataset_name}'' does not refer to a known dataset.")

    print(f"Loading dataset {dataset_name}, image size={str(image_size)} ...")

    cameras_path = os.path.join(data_root, dataset_name + ".pth")
    image_path = cameras_path.replace(".pth", ".png")

    if autodownload and any(not os.path.isfile(p)
                            for p in (cameras_path, image_path)):
        # Automatically download the data files if missing.
        download_data((dataset_name, ), data_root=data_root)

    train_data = torch.load(cameras_path)
    n_cameras = train_data["cameras"]["R"].shape[0]

    _image_max_image_pixels = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None  # The dataset image is very large ...
    images = torch.FloatTensor(np.array(Image.open(image_path))) / 255.0
    images = torch.stack(torch.chunk(images, n_cameras, dim=0))[..., :3]
    Image.MAX_IMAGE_PIXELS = _image_max_image_pixels

    scale_factors = [
        s_new / s for s, s_new in zip(images.shape[1:3], image_size)
    ]
    if abs(scale_factors[0] - scale_factors[1]) > 1e-3:
        raise ValueError(
            "Non-isotropic scaling is not allowed. Consider changing the 'image_size' argument."
        )
    scale_factor = sum(scale_factors) * 0.5

    if scale_factor != 1.0:
        print(f"Rescaling dataset (factor={scale_factor})")
        images = torch.nn.functional.interpolate(
            images.permute(0, 3, 1, 2),
            size=tuple(image_size),
            mode="bilinear",
        ).permute(0, 2, 3, 1)

    cameras = [
        PerspectiveCameras(
            **{k: v[cami][None]
               for k, v in train_data["cameras"].items()}).to("cpu")
        for cami in range(n_cameras)
    ]

    train_idx, val_idx, test_idx = train_data["split"]

    train_dataset, val_dataset, test_dataset = [
        ListDataset([{
            "image": images[i],
            "camera": cameras[i],
            "camera_idx": int(i)
        } for i in idx]) for idx in [train_idx, val_idx, test_idx]
    ]

    return train_dataset, val_dataset, test_dataset
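
A short usage sketch, assuming the 'lego' files are available under DEFAULT_DATA_ROOT (or can be auto-downloaded); the image size is illustrative.

train_ds, val_ds, test_ds = get_nerf_datasets("lego", image_size=(128, 128))
print(len(train_ds), len(val_ds), len(test_ds))

# each element is a dict holding an "image" (H, W, 3), a single PerspectiveCameras
# instance under "camera", and the integer "camera_idx"
sample = train_ds[0]
print(sample["image"].shape, sample["camera_idx"])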
Example No. 22
                         specular_color=((0.0, 0.0, 0.0), ))
    RRy, TTy = look_at_view_transform(dist=8,
                                      elev=0,
                                      azim=yaw_dim,
                                      up=((0, 1, 0), ),
                                      device=device)

    TTx = TTy[:args.pitch]
    RRx = get_R_matrix(azim=pitch_dim, axis="Rx")

    Rtotal = torch.cat([RRy, RRx], dim=0)
    Ttotal = torch.cat([TTy, TTx], dim=0)

    cameras = PerspectiveCameras(device=device,
                                 focal_length=4500,
                                 principal_point=((512, 512), ),
                                 R=Rtotal,
                                 T=Ttotal,
                                 image_size=((1024, 1024), ))

    if num_views != 1:
        camera = PerspectiveCameras(device=device,
                                    focal_length=4500,
                                    principal_point=((512, 512), ),
                                    R=Rtotal[None, 1, ...],
                                    T=Ttotal[None, 1, ...],
                                    image_size=((1024, 1024), ))
    else:
        camera = PerspectiveCameras(device=device,
                                    focal_length=4500,
                                    principal_point=((512, 512), ),
                                    R=Rtotal,
Example No. 23
def generate_eval_video_cameras(
    train_dataset,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
) -> list:
    """
    Generate a camera trajectory for visualizing a NeRF model.

    Args:
        train_dataset: The training dataset object.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular: Rotating around the center of the scene at a fixed radius.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
    Returns:
        A list of dictionaries, one per camera, that can be used as the test dataset.
    """
    if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )
        # get the nearest camera center to the mean of centers
        mean_camera_idx = (
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )
        # generate the knot trajectory in canonical coords
        time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        traj_trans = (
            train_dataset[mean_camera_idx]["camera"]
            .get_world_to_view_transform()
            .inverse()
        )
        traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
        traj = traj_trans.transform_points(traj)

    elif trajectory_type == "circular":
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )

        # fit plane to the camera centers
        plane_mean = cam_centers.mean(dim=0)
        cam_centers_c = cam_centers - plane_mean[None]

        if up is not None:
            # use the up vector instead of the normal of the plane fitted to the camera centers
            plane_normal = torch.FloatTensor(up)
        else:
            cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
            _, e_vec = torch.symeig(cov, eigenvectors=True)
            plane_normal = e_vec[:, 0]

        plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
        cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]

        cov = (
            cam_centers_on_plane.t() @ cam_centers_on_plane
        ) / cam_centers_on_plane.shape[0]
        _, e_vec = torch.symeig(cov, eigenvectors=True)
        traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
        angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
        traj = traj_radius * torch.stack(
            (torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
        )
        traj = traj @ e_vec.t() + plane_mean[None]

    else:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
    p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)

    # assemble the dataset
    test_dataset = [
        {
            "image": None,
            "camera": PerspectiveCameras(
                focal_length=focal[None],
                principal_point=p0[None],
                R=R_[None],
                T=T_[None],
            ),
            "camera_idx": i,
        }
        for i, (R_, T_) in enumerate(zip(R, T))
    ]

    return test_dataset
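
A hedged usage sketch: build the training split with get_nerf_datasets above and generate a 100-frame circular fly-around (the trajectory parameters are illustrative).

train_dataset, _, _ = get_nerf_datasets("lego", image_size=(128, 128))
eval_cams = generate_eval_video_cameras(
    train_dataset,
    n_eval_cams=100,
    trajectory_type="circular",
    trajectory_scale=0.2,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, 0.0, 1.0),
)
# each entry carries a single-camera PerspectiveCameras object and no image
print(len(eval_cams), eval_cams[0]["camera"].R.shape)  # 100, torch.Size([1, 3, 3])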
Example No. 24
    def forward(self,
                verts,        # under general camera coordinate rXdYfZ,   N*V*3
                faces,        # indices in verts to define triangles,     N*F*3
                verts_uvs,    # uv coordinates of corresponding verts,    N*V*2
                faces_uvs,    # indices in verts_uvs to define triangles, N*F*3
                tex_image,    # under GCcd,                            N*H*W*3
                R,            # under GCcd,                              N*3*3 
                T,            # under GCcd,                              N*3
                f,            # in pixel/m,                              N*1
                C,            # camera center,                           N*2
                imgres,       # int
                lightLoc = None):
        
        assert verts.shape[0] == 1, \
            'due to PyTorch3D limitations, render one mesh per forward pass'
        
        # only need to convert either R and T or verts, we choose R and T here
        if self.convertToPytorch3D:
            R = torch.matmul(self.GCcdToPytorch3D, R)
            T = torch.matmul(self.GCcdToPytorch3D, T.unsqueeze(-1)).squeeze(-1)
        
        # prepare textures and mesh to render
        tex = TexturesUV(
            verts_uvs = verts_uvs, 
            faces_uvs = faces_uvs, 
            maps = tex_image
        ) 
        mesh = Meshes(verts = verts, faces = faces, textures=tex)
        
        # Initialize a camera. The world coordinate system is +Y up, +X left and +Z in.
        cameras = PerspectiveCameras(
            focal_length=f,
            principal_point=C,
            R=R, 
            T=T, 
            image_size=((imgres,imgres),),
            device=self.device
        )

        # Define the settings for rasterization and shading. 
        raster_settings = RasterizationSettings(
            image_size=imgres, 
            blur_radius=0.0, 
            faces_per_pixel=1, 
        )

        # Create a simple renderer by composing a rasterizer and a shader.
        # The simple textured shader interpolates the texture uv coordinates
        # for each pixel and samples from the texture image. This renderer could
        # support lighting easily, but we do not implement it here.
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras, 
                raster_settings=raster_settings
            ),
            shader=SimpleShader(
                device=self.device
            )
        )
        
        # render the image(s)
        images = renderer(mesh)
        
        return images
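
The GCcdToPytorch3D buffer referenced above is defined outside this snippet. Given the rXdYfZ (right/down/forward, OpenCV-style) source frame and PyTorch3D's left/up/forward camera frame, it is presumably a flip of the first two axes; a minimal sketch of that assumption:

import torch

# Assumed change of basis from an OpenCV-style camera frame (+X right, +Y down,
# +Z forward) to PyTorch3D's camera frame (+X left, +Y up, +Z forward): flip X and Y.
# This only illustrates what self.GCcdToPytorch3D presumably holds; the actual
# buffer lives outside the snippet above.
GCcdToPytorch3D = torch.diag(torch.tensor([-1.0, -1.0, 1.0]))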