Example #1
    def _voxel_size_translation_to_transform(
        self,
        voxel_size: torch.Tensor,
        volume_translation: torch.Tensor,
        batch_size: int,
    ) -> Transform3d:
        """
        Converts the `voxel_size` and `volume_translation` constructor arguments
        to the internal `Transform3d` object `local_to_world_transform`.
        """
        volume_size_zyx = self.get_grid_sizes().float()
        volume_size_xyz = volume_size_zyx[:, [2, 1, 0]]

        # x_local = (
        #       (x_world + volume_translation) / (0.5 * voxel_size)
        #   ) / (volume_size - 1)

        # x_world = (
        #       x_local * (volume_size - 1) * 0.5 * voxel_size
        #   ) - volume_translation

        local_to_world_transform = Scale(
            (volume_size_xyz - 1) * voxel_size * 0.5,
            device=self.device).translate(-volume_translation)

        return local_to_world_transform
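
A quick numerical check of the local/world mapping spelled out in the comments above; a minimal sketch assuming only pytorch3d.transforms.Scale, with made-up volume sizes, voxel sizes and translations.

import torch
from pytorch3d.transforms import Scale

volume_size_xyz = torch.tensor([[5.0, 4.0, 3.0]])       # (1, 3), XYZ order
voxel_size = torch.tensor([[2.0, 2.0, 2.0]])             # (1, 3)
volume_translation = torch.tensor([[1.0, 0.0, -1.0]])    # (1, 3)

t = Scale((volume_size_xyz - 1) * voxel_size * 0.5).translate(-volume_translation)

# the local corner (-1, -1, -1) should map to
# -(volume_size - 1) * 0.5 * voxel_size - volume_translation in world coords
corner_local = torch.tensor([[[-1.0, -1.0, -1.0]]])      # (1, 1, 3)
corner_world = t.transform_points(corner_local)
expected = -(volume_size_xyz - 1) * 0.5 * voxel_size - volume_translation
assert torch.allclose(corner_world[:, 0], expected)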
Example #2
    def normalize_to_box_(self):
        """
        Center and scale the point clouds so that the longest bounding-box side becomes 1 (unit cube).
        Returns:
            normalizing_trans (Transform3D): the Transform3D used to normalize the point clouds.
        """
        # (B,3,2)
        boxMinMax = self.get_bounding_boxes()
        boxCenter = boxMinMax.sum(dim=-1) / 2
        # (B,)
        boxRange, _ = (boxMinMax[:, :, 1] - boxMinMax[:, :, 0]).max(dim=-1)
        # guard against degenerate clouds elementwise, so batches of size > 1 also work
        boxRange = torch.where(boxRange == 0, torch.ones_like(boxRange), boxRange)

        # center and scale the point clouds, likely faster than calling obj2world_trans directly?
        pointOffsets = torch.repeat_interleave(-boxCenter,
                                               self.num_points_per_cloud(),
                                               dim=0)
        self.offset_(pointOffsets)
        self.scale_(1 / boxRange)

        # update obj2world_trans
        normalizing_trans = Translate(-boxCenter).compose(Scale(
            1 / boxRange)).to(device=self.device)
        self.obj2world_trans = normalizing_trans.inverse().compose(
            self.obj2world_trans)
        return normalizing_trans
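
The same normalization written out for a plain (B, P, 3) points tensor; a sketch that mirrors the method above without requiring the point-cloud class it lives on, with illustrative shapes that are not taken from the original repository.

import torch

points = torch.randn(2, 100, 3) * 4.0 + 1.0              # (B, P, 3)
box_min = points.min(dim=1).values                         # (B, 3)
box_max = points.max(dim=1).values                         # (B, 3)
box_center = (box_min + box_max) / 2                       # (B, 3)
box_range = (box_max - box_min).max(dim=-1).values         # (B,)
box_range = torch.where(box_range == 0, torch.ones_like(box_range), box_range)

normalized = (points - box_center[:, None]) / box_range[:, None, None]

# the longest bounding-box side of every cloud is now 1
sides = normalized.max(dim=1).values - normalized.min(dim=1).values
assert torch.allclose(sides.max(dim=-1).values, torch.ones(2))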
Example #3
    def normalize_to_sphere_(self):
        """
        Center and scale the point clouds to a unit sphere
        Returns: normalizing_trans (Transform3D)
        """
        # (B,3,2)
        boxMinMax = self.get_bounding_boxes()
        boxCenter = boxMinMax.sum(dim=-1) / 2
        # (B,)
        boxRange, _ = (boxMinMax[:, :, 1] - boxMinMax[:, :, 0]).max(dim=-1)
        # guard against degenerate clouds elementwise, so batches of size > 1 also work
        boxRange = torch.where(boxRange == 0, torch.ones_like(boxRange), boxRange)

        # center and scale the point clouds, likely faster than calling obj2world_trans directly?
        pointOffsets = torch.repeat_interleave(-boxCenter,
                                               self.num_points_per_cloud(),
                                               dim=0)
        self.offset_(pointOffsets)
        # (P)
        norms = torch.norm(self.points_packed(), dim=-1)
        # List[(Pi)]
        norms = torch.split(norms, self.num_points_per_cloud().tolist())
        # (N)
        scale = torch.stack([x.max() for x in norms], dim=0)
        self.scale_(1 / eps_denom(scale))
        normalizing_trans = Translate(-boxCenter).compose(
            Scale(1 / eps_denom(scale))).to(device=self.device)
        self.obj2world_trans = normalizing_trans.inverse().compose(
            self.obj2world_trans)
        return normalizing_trans
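
As above, a tensor-only sketch of the sphere normalization for a plain (B, P, 3) points tensor; torch.clamp stands in for the eps_denom helper used in the method, which isn't shown here, and the shapes are illustrative.

import torch

points = torch.randn(2, 100, 3) * 4.0 + 1.0              # (B, P, 3)
box_min = points.min(dim=1).values                         # (B, 3)
box_max = points.max(dim=1).values                         # (B, 3)
box_center = (box_min + box_max) / 2                       # (B, 3)

centered = points - box_center[:, None]                    # (B, P, 3)
radius = centered.norm(dim=-1).max(dim=-1).values          # (B,)
normalized = centered / radius.clamp(min=1e-17)[:, None, None]

# every cloud now touches (and fits inside) the unit sphere
assert torch.allclose(normalized.norm(dim=-1).max(dim=-1).values, torch.ones(2))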
Example #4
    def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
        """
        Test the correctness of the conversion between the internal
        Transform3D Volumes._local_to_world_transform and the initialization
        from the translation and voxel_size.
        """

        device = torch.device("cuda:0")

        # try for 10 sets of different random sizes/centers/voxel_sizes
        for _ in range(10):

            size = torch.randint(high=10, size=(3,), low=3).tolist()

            densities = torch.randn(
                size=[num_volumes, num_channels, *size],
                device=device,
                dtype=torch.float32,
            )

            # init the transformation params
            volume_translation = torch.randn(num_volumes, 3, device=device)
            voxel_size = torch.rand(num_volumes, 3, device=device) * 3.0 + 0.5

            # get the corresponding Transform3d object
            local_offset = torch.tensor(list(size), dtype=torch.float32, device=device)[
                [2, 1, 0]
            ][None].repeat(num_volumes, 1)
            local_to_world_transform = (
                Scale(0.5 * local_offset - 0.5, device=device)
                .scale(voxel_size)
                .translate(-volume_translation)
            )

            # init the volume structures with the scale and translation,
            # then get the coord grid in world coords
            v_trans_vs = Volumes(
                densities=densities,
                voxel_size=voxel_size,
                volume_translation=volume_translation,
            )
            grid_rot_trans_vs = v_trans_vs.get_coord_grid(world_coordinates=True)

            # map the default local coords to the world coords
            # with local_to_world_transform
            v_default = Volumes(densities=densities)
            grid_default_local = v_default.get_coord_grid(world_coordinates=False)
            grid_default_world = local_to_world_transform.transform_points(
                grid_default_local.view(num_volumes, -1, 3)
            ).view(num_volumes, *size, 3)

            # check that both grids are the same
            self.assertClose(grid_rot_trans_vs, grid_default_world, atol=1e-5)

            # check that the transformations are the same
            self.assertClose(
                v_trans_vs.get_local_to_world_coords_transform().get_matrix(),
                local_to_world_transform.get_matrix(),
                atol=1e-5,
            )
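
The same round trip outside the test harness, as a minimal sketch assuming pytorch3d.structures.Volumes is available; the grid size, voxel size, and translation below are arbitrary.

import torch
from pytorch3d.structures import Volumes

densities = torch.zeros(1, 1, 3, 4, 5)                     # (N, C, D, H, W)
vols = Volumes(
    densities=densities,
    voxel_size=torch.tensor([[0.5, 0.5, 0.5]]),
    volume_translation=torch.tensor([[1.0, 2.0, 3.0]]),
)

grid_local = vols.get_coord_grid(world_coordinates=False)   # (N, D, H, W, 3)
grid_world = vols.get_coord_grid(world_coordinates=True)
t = vols.get_local_to_world_coords_transform()
mapped = t.transform_points(grid_local.reshape(1, -1, 3)).reshape(grid_world.shape)
assert torch.allclose(mapped, grid_world, atol=1e-5)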
Example #5
 def normalize_to_sphere_(self):
     """
     Center and scale the point clouds to a unit sphere
     Returns: normalizing_trans (Transform3D)
     """
     # packed offset
     center = torch.stack([x.mean(dim=0) for x in self.points_list()],
                          dim=0)
     center_packed = torch.repeat_interleave(-center,
                                             self.num_points_per_cloud(),
                                             dim=0)
     self.offset_(center_packed)
     # (P)
     norms = torch.norm(self.points_packed(), dim=-1)
     # List[(Pi)]
     norms = torch.split(norms, self.num_points_per_cloud().tolist())
     # (N)
     scale = torch.stack([x.max() for x in norms], dim=0)
     self.scale_(1 / eps_denom(scale))
     normalizing_trans = Translate(-center).compose(
         Scale(1 / eps_denom(scale))).to(device=self.device)
     self.obj2world_trans = normalizing_trans.inverse().compose(
         self.obj2world_trans)
     return normalizing_trans
Example #6
def generate_eval_video_cameras(
    train_cameras,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
    focal_length: Optional[torch.FloatTensor] = None,
    principal_point: Optional[torch.FloatTensor] = None,
    time: Optional[torch.FloatTensor] = None,
    infer_up_as_plane_normal: bool = True,
    traj_offset: Optional[Tuple[float, float, float]] = None,
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
    """
    Generate a camera trajectory rendering a scene from multiple viewpoints.

    Args:
        train_cameras: The batch of cameras from the training dataset.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for `trajectory_type="circular_lsq_fit"`.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
    Returns:
        A batch of PerspectiveCameras which can be used as the test dataset.
    """
    if trajectory_type in ("figure_eight", "trefoil_knot",
                           "figure_eight_knot"):
        cam_centers = train_cameras.get_camera_center()
        # get the nearest camera center to the mean of centers
        mean_camera_idx = (((cam_centers -
                             cam_centers.mean(dim=0)[None])**2).sum(dim=1).min(
                                 dim=0).indices)
        # generate the knot trajectory in canonical coords
        if time is None:
            time = torch.linspace(0, 2 * math.pi,
                                  n_eval_cams + 1)[:n_eval_cams]
        else:
            assert time.numel() == n_eval_cams
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        else:
            raise ValueError(f"bad trajectory type: {trajectory_type}")
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        mean_camera = PerspectiveCameras(
            **{
                k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
                for k in ("focal_length", "principal_point", "R", "T")
            })
        traj_trans = Scale(
            cam_centers.std(dim=0).mean() * trajectory_scale).compose(
                mean_camera.get_world_to_view_transform().inverse())

        if traj_offset_canonical is not None:
            traj_trans = traj_trans.translate(
                torch.FloatTensor(traj_offset_canonical)[None].to(traj))

        traj = traj_trans.transform_points(traj)

        plane_normal = _fit_plane(cam_centers)[:, 0]
        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    elif trajectory_type == "circular_lsq_fit":
        ### fit plane to the camera centers

        # gather the training camera centers; fit_circle_in_3d below fits the plane and circle
        cam_centers = train_cameras.get_camera_center()

        if time is not None:
            angle = time
        else:
            angle = torch.linspace(0, 2.0 * math.pi,
                                   n_eval_cams).to(cam_centers)

        fit = fit_circle_in_3d(
            cam_centers,
            angles=angle,
            offset=angle.new_tensor(traj_offset_canonical)
            if traj_offset_canonical is not None else None,
            up=angle.new_tensor(up),
        )
        traj = fit.generated_points

        # scale the trajectory
        _t_mu = traj.mean(dim=0, keepdim=True)
        traj = (traj - _t_mu) * trajectory_scale + _t_mu

        plane_normal = fit.normal

        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    else:
        raise ValueError(f"Uknown trajectory_type {trajectory_type}.")

    if traj_offset is not None:
        traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center, ),  # (1, 3)
        up=(up, ),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    if focal_length is None:
        focal_length = train_cameras.focal_length.mean(dim=0).repeat(
            n_eval_cams, 1)
    if principal_point is None:
        principal_point = train_cameras.principal_point.mean(dim=0).repeat(
            n_eval_cams, 1)

    test_cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=focal_length.device,
    )

    # _visdom_plot_scene(
    #     train_cameras,
    #     test_cameras,
    # )

    return test_cameras
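
A minimal usage sketch, assuming train_cameras is an existing PerspectiveCameras batch (e.g. the cameras of a training dataset) and that generate_eval_video_cameras and its helpers above are importable.

test_cameras = generate_eval_video_cameras(
    train_cameras,
    n_eval_cams=50,
    trajectory_type="circular_lsq_fit",
    trajectory_scale=1.1,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, 0.0, 1.0),
)
# test_cameras is a PerspectiveCameras batch of 50 views whose centers lie on a
# circle fitted to the training camera centers, all looking at scene_center.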