Example #1
    def test_save_load_with_normals(self):
        points = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]],
                              dtype=torch.float32)
        normals = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]],
                               dtype=torch.float32)
        features = torch.rand_like(points)

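        # exercise all four combinations of optional features and normals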
        for do_features, do_normals in itertools.product([True, False],
                                                         [True, False]):
            cloud = Pointclouds(
                points=[points],
                features=[features] if do_features else None,
                normals=[normals] if do_normals else None,
            )
            device = torch.device("cuda:0")

            io = IO()
            with NamedTemporaryFile(mode="w", suffix=".ply") as f:
                io.save_pointcloud(cloud.cuda(), f.name)
                f.flush()
                cloud2 = io.load_pointcloud(f.name, device=device)
            self.assertEqual(cloud2.device, device)
            cloud2 = cloud2.cpu()
            self.assertClose(cloud2.points_padded(), cloud.points_padded())
            if do_normals:
                self.assertClose(cloud2.normals_padded(),
                                 cloud.normals_padded())
            else:
                self.assertIsNone(cloud.normals_padded())
                self.assertIsNone(cloud2.normals_padded())
            if do_features:
                self.assertClose(cloud2.features_packed(), features)
            else:
                self.assertIsNone(cloud2.features_packed())
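For reference, the same PLY round trip outside the test harness; a minimal CPU-only sketch (the tensors here are illustrative, not from the test):

import torch
from tempfile import NamedTemporaryFile
from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds

points = torch.rand(100, 3)
normals = torch.nn.functional.normalize(torch.rand(100, 3) - 0.5, dim=-1)
cloud = Pointclouds(points=[points], normals=[normals])

io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
    io.save_pointcloud(cloud, f.name)
    reloaded = io.load_pointcloud(f.name)

assert torch.allclose(reloaded.points_padded(), cloud.points_padded(), atol=1e-5)
assert torch.allclose(reloaded.normals_padded(), cloud.normals_padded(), atol=1e-5)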
Example #2
def eval_one_dir(exp_dir, n_pts=50000):
    """
    Function for one directory
    """
    device = torch.device('cuda:0')
    cfg = config.load_config(os.path.join(exp_dir, 'config.yaml'))
    dataset = config.create_dataset(cfg.data, mode='val')
    meshes_gt = dataset.get_meshes().to(device)
    val_gt_pts_file = os.path.join(cfg.data.data_dir, 'val%d.ply' % n_pts)
    if os.path.isfile(val_gt_pts_file):
        points, normals = np.split(read_ply(val_gt_pts_file), 2, axis=1)
        pcl_gt = Pointclouds(
            torch.from_numpy(points[None, ...]).float(),
            torch.from_numpy(normals[None, ...]).float()).to(device)
    else:
        pcl_gt = dataset.get_pointclouds(n_pts).to(device)
        trimesh.Trimesh(pcl_gt.points_packed().cpu().numpy(),
                        vertex_normals=pcl_gt.normals_packed().cpu().numpy(),
                        process=False).export(val_gt_pts_file,
                                              vertex_normal=True)

    # load vis directories
    vis_dir = os.path.join(exp_dir, 'vis')
    vis_files = sorted(get_filenames(vis_dir, '_mesh.ply'))
    iters = [int(os.path.basename(v).split('_')[0]) for v in vis_files]
    best_dict = defaultdict(lambda: float('inf'))
    vis_eval_csv = os.path.join(vis_dir, "evaluation_n%d.csv" % n_pts)
    if not os.path.isfile(vis_eval_csv):
        with open(vis_eval_csv, "w") as f:
            fieldnames = ['mtime', 'it', 'chamfer_p', 'chamfer_n', 'pf_dist']
            writer = csv.DictWriter(f,
                                    fieldnames=fieldnames,
                                    restval="-",
                                    extrasaction="ignore")
            writer.writeheader()
            mtime0 = None
            for it, vis_file in zip(iters, vis_files):
                eval_dict = OrderedDict()
                mtime = os.path.getmtime(vis_file)
                if mtime0 is None:
                    mtime0 = mtime
                eval_dict['it'] = it
                eval_dict['mtime'] = mtime - mtime0
                val_pts_file = os.path.join(
                    vis_dir,
                    os.path.basename(vis_file).replace('_mesh',
                                                       '_val%d' % n_pts))
                if os.path.isfile(val_pts_file):
                    points, normals = np.split(read_ply(val_pts_file),
                                               2,
                                               axis=1)
                    points = torch.from_numpy(points).float().to(
                        device=device).view(1, -1, 3)
                    normals = torch.from_numpy(normals).float().to(
                        device=device).view(1, -1, 3)
                else:
                    mesh = trimesh.load(vis_file, process=False)
                    # alternative (disabled): poisson-disk sampling of n_pts
                    # points and normals from the mesh via
                    # pcu.sample_mesh_poisson_disk
                    meshes = Meshes(
                        torch.from_numpy(mesh.vertices[None, ...]).float(),
                        # face indices must be integral to index into vertices
                        torch.from_numpy(mesh.faces[None, ...]).long()).to(device)
                    points, normals = sample_points_from_meshes(
                        meshes, n_pts, return_normals=True)
                    trimesh.Trimesh(points.cpu().numpy()[0],
                                    vertex_normals=normals.cpu().numpy()[0],
                                    process=False).export(val_pts_file,
                                                          vertex_normal=True)
                pcl = Pointclouds(points, normals)
                chamfer_p, chamfer_n = chamfer_distance(
                    points,
                    pcl_gt.points_padded(),
                    x_normals=normals,
                    y_normals=pcl_gt.normals_padded(),
                )
                eval_dict['chamfer_p'] = chamfer_p.item()
                eval_dict['chamfer_n'] = chamfer_n.item()
                pf_dist = point_mesh_face_distance(meshes_gt, pcl)
                eval_dict['pf_dist'] = pf_dist.item()
                writer.writerow(eval_dict)
                for k, v in eval_dict.items():
                    if v < best_dict[k]:
                        best_dict[k] = v
                        print('best {} so far ({}): {:.4g}'.format(
                            k, vis_file, v))

    # generation directories
    gen_dir = os.path.join(exp_dir, 'generation')
    if not os.path.isdir(gen_dir):
        return

    final_file = os.path.join(gen_dir, 'mesh.ply')
    val_pts_file = final_file[:-4] + '_val%d' % n_pts + '.ply'
    if not os.path.isfile(final_file):
        return

    gen_file_csv = os.path.join(gen_dir, "evaluation_n%d.csv" % n_pts)
    if not os.path.isfile(gen_file_csv):
        with open(gen_file_csv, "w") as f:
            fieldnames = ['chamfer_p', 'chamfer_n', 'pf_dist']
            writer = csv.DictWriter(f,
                                    fieldnames=fieldnames,
                                    restval="-",
                                    extrasaction="ignore")
            writer.writeheader()
            eval_dict = OrderedDict()
            mesh = trimesh.load(final_file, process=False)
            # alternative (disabled): poisson-disk sampling of n_pts points
            # and normals from the mesh via pcu.sample_mesh_poisson_disk
            meshes = Meshes(
                torch.from_numpy(mesh.vertices[None, ...]).float(),
                # face indices must be integral to index into vertices
                torch.from_numpy(mesh.faces[None, ...]).long()).to(device)
            points, normals = sample_points_from_meshes(meshes,
                                                        n_pts,
                                                        return_normals=True)
            trimesh.Trimesh(points.cpu().numpy()[0],
                            vertex_normals=normals.cpu().numpy()[0],
                            process=False).export(val_pts_file,
                                                  vertex_normal=True)
            pcl = Pointclouds(points, normals)
            chamfer_p, chamfer_n = chamfer_distance(
                points,
                pcl_gt.points_padded(),
                x_normals=normals,
                y_normals=pcl_gt.normals_padded(),
            )
            eval_dict['chamfer_p'] = chamfer_p.item()
            eval_dict['chamfer_n'] = chamfer_n.item()
            pf_dist = point_mesh_face_distance(meshes_gt, pcl)
            eval_dict['pf_dist'] = pf_dist.item()
            writer.writerow(eval_dict)
            for k, v in eval_dict.items():
                if v < best_dict[k]:
                    best_dict[k] = v
                    print('best {} so far ({}): {:.4g}'.format(
                        k, final_file, v))
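The evaluation core above reduces to this pattern; a minimal sketch using a built-in test mesh (ico_sphere is only a stand-in for the reconstructed and ground-truth meshes):

import torch
from pytorch3d.loss import chamfer_distance, point_mesh_face_distance
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Pointclouds
from pytorch3d.utils import ico_sphere

mesh_gt = ico_sphere(level=2)
mesh_pred = ico_sphere(level=2)

# sample oriented point clouds from both meshes; two independent random
# samplings of the same surface give a small nonzero chamfer distance
pts_gt, nrm_gt = sample_points_from_meshes(mesh_gt, 5000, return_normals=True)
pts_pred, nrm_pred = sample_points_from_meshes(mesh_pred, 5000, return_normals=True)

# chamfer_p compares positions, chamfer_n compares normals
chamfer_p, chamfer_n = chamfer_distance(
    pts_pred, pts_gt, x_normals=nrm_pred, y_normals=nrm_gt)

# distance from the predicted points to the ground-truth surface
pf_dist = point_mesh_face_distance(mesh_gt, Pointclouds(pts_pred, nrm_pred))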
Example #3
    def forward(
        self,
        camera: CamerasBase,
        image_rgb: torch.Tensor,
        depth_map: torch.Tensor,
        fg_probability: torch.Tensor,
        frame_type: List[str],
        **kwargs,
    ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
        """
        Given a batch of source cameras, images, and depth maps, unprojects
        all RGBD maps to a colored point cloud and renders it into the target views.

        Args:
            camera: A batch of `N` PyTorch3D cameras.
            image_rgb: A batch of `N` images of shape `(N, 3, H, W)`.
            depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
            fg_probability: A batch of `N` foreground probability maps
                of shape `(N, 1, H, W)`.
            frame_type: A list of `N` strings containing frame type indicators
                which specify target and source views.

        Returns:
            preds: A dict with the following fields:
                nvs_prediction: The rendered colors, depth and mask
                    of the target views.
                point_cloud: The point cloud of the scene. Its renders are
                    stored in `nvs_prediction`.
        """

        is_known = is_known_frame(frame_type)
        is_known_idx = torch.where(is_known)[0]

        mask_fg = (fg_probability > 0.5).type_as(image_rgb)

        point_cloud = get_rgbd_point_cloud(
            camera[is_known_idx],
            image_rgb[is_known_idx],
            depth_map[is_known_idx],
            mask_fg[is_known_idx],
        )

        pcl_size = int(point_cloud.num_points_per_cloud())
        if (self.max_points > 0) and (pcl_size > self.max_points):
            prm = torch.randperm(pcl_size)[:self.max_points]
            point_cloud = Pointclouds(
                point_cloud.points_padded()[:, prm, :],
                # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
                features=point_cloud.features_padded()[:, prm, :],
            )

        is_target_idx = torch.where(~is_known)[0]

        depth_render, image_render, mask_render = [], [], []

        # render into target frames in a for loop to save memory
        for tgt_idx in is_target_idx:
            _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                camera[int(tgt_idx)],
                point_cloud,
                render_size=(self.image_size, self.image_size),
                point_radius=1e-2,
                topk=10,
                bg_color=self.bg_color,
            )
            _image_render = _image_render.clamp(0.0, 1.0)
            # the mask is the set of pixels with opacity bigger than eps
            _mask_render = (_mask_render > 1e-4).float()

            depth_render.append(_depth_render)
            image_render.append(_image_render)
            mask_render.append(_mask_render)

        nvs_prediction = NewViewSynthesisPrediction(
            **{
                k: torch.cat(v, dim=0)
                for k, v in zip(
                    ["depth_render", "image_render", "mask_render"],
                    [depth_render, image_render, mask_render],
                )
            })

        preds = {
            "nvs_prediction": nvs_prediction,
            "point_cloud": point_cloud,
        }

        return preds
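The `max_points` cap above is a generic pattern for subsampling a single `Pointclouds`; a minimal standalone sketch (sizes are illustrative):

import torch
from pytorch3d.structures import Pointclouds

cloud = Pointclouds(
    points=torch.rand(1, 20000, 3),
    features=torch.rand(1, 20000, 3),
)

max_points = 5000
n = int(cloud.num_points_per_cloud())
if n > max_points:
    # keep a random subset of the points (and their features)
    prm = torch.randperm(n)[:max_points]
    cloud = Pointclouds(
        points=cloud.points_padded()[:, prm, :],
        features=cloud.features_padded()[:, prm, :],
    )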
Example #4
    def test_ndc_grid_sample_rendering(self):
        """
        Use PyTorch3D point renderer to render a colored point cloud, then
        sample the image at the locations of the point projections with
        `ndc_grid_sample`. Finally, assert that the sampled colors are equal to the
        original point cloud colors.

        Note that, in order to ensure correctness, we use a nearest-neighbor
        assignment point renderer (i.e. no soft splatting).
        """

        # generate a bunch of 3D points on a regular grid lying in the z-plane
        n_grid_pts = 10
        grid_scale = 0.9
        z_plane = 2.0
        image_size = [128, 128]
        point_radius = 0.015
        n_pts = n_grid_pts * n_grid_pts
        pts = torch.stack(
            meshgrid_ij([torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2),
            dim=-1,
        )
        pts = torch.cat([pts, z_plane * torch.ones_like(pts[..., :1])], dim=-1)
        pts = pts.reshape(1, n_pts, 3)

        # color the points randomly
        pts_colors = torch.rand(1, n_pts, 3)

        # make trivial rendering cameras
        cameras = PerspectiveCameras(
            R=eyes(dim=3, N=1),
            device=pts.device,
            T=torch.zeros(1, 3, dtype=torch.float32, device=pts.device),
        )

        # render the point cloud
        pcl = Pointclouds(points=pts, features=pts_colors)
        renderer = NearestNeighborPointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras,
                raster_settings=PointsRasterizationSettings(
                    image_size=image_size,
                    radius=point_radius,
                    points_per_pixel=1,
                ),
            ),
            compositor=AlphaCompositor(),
        )
        im_render = renderer(pcl)

        # sample the render at projected pts
        pts_proj = cameras.transform_points(pcl.points_padded())[..., :2]
        pts_colors_sampled = ndc_grid_sample(
            im_render,
            pts_proj,
            mode="nearest",
            align_corners=False,
        ).permute(0, 2, 1)

        # assert that the samples are the same as original points
        self.assertClose(pts_colors, pts_colors_sampled, atol=1e-4)
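The renderer setup generalizes to the standard `PointsRenderer` (the test's `NearestNeighborPointsRenderer` is a test-local variant); a minimal sketch:

import torch
from pytorch3d.renderer import (
    AlphaCompositor,
    PerspectiveCameras,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)
from pytorch3d.structures import Pointclouds

# random colored points in front of the default camera (looking along +z)
pts = torch.rand(1, 500, 3) - 0.5 + torch.tensor([0.0, 0.0, 2.0])
pcl = Pointclouds(points=pts, features=torch.rand(1, 500, 3))

renderer = PointsRenderer(
    rasterizer=PointsRasterizer(
        cameras=PerspectiveCameras(),
        raster_settings=PointsRasterizationSettings(
            image_size=128, radius=0.015, points_per_pixel=8),
    ),
    compositor=AlphaCompositor(),
)
image = renderer(pcl)  # (1, 128, 128, 3)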
Example #5
def cleanup_eval_depth(
    point_cloud: Pointclouds,
    camera: CamerasBase,
    depth: torch.Tensor,
    mask: torch.Tensor,
    sigma: float = 0.01,
    image=None,
):
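    """
    Flag each point of `point_cloud` as reliable if its view-space depth
    agrees with the depth map sampled at its projection, up to `sigma`
    times the (masked) standard deviation of `depth`; return a binary
    (ba, 1, H, W) mask of pixels hit by at least one reliable point.
    """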

    ba, _, H, W = depth.shape

    pcl = point_cloud.points_padded()
    n_pts = point_cloud.num_points_per_cloud()
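    # mask of valid (non-padded) entries in the padded points tensor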
    pcl_mask = (
        torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
        < n_pts[:, None]
    ).type_as(pcl)

    pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
    pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]

    depth_and_idx = torch.cat(
        (
            depth,
            torch.arange(H * W).view(1, 1, H, W).expand(ba, 1, H, W).type_as(depth),
        ),
        dim=1,
    )

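    # note the sign flip on pcl_proj: pytorch3d NDC has +x left / +y up,
    # while grid_sample expects +x right / +y down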
    depth_and_idx_sampled = Fu.grid_sample(
        depth_and_idx, -pcl_proj[:, None], mode="nearest"
    )[:, :, 0].view(ba, 2, -1)

    depth_sampled, idx_sampled = depth_and_idx_sampled.split([1, 1], dim=1)
    df = (depth_sampled[:, 0] - pcl_depth).abs()

    # the threshold is a sigma-multiple of the standard deviation of the depth
    mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
    std = (
        wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
        .clamp(1e-4)
        .sqrt()
        .view(ba, -1)
    )
    good_df_thr = std * sigma
    good_depth = (df <= good_df_thr).float() * pcl_mask

    perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
    # print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')

    good_depth_raster = torch.zeros_like(depth).view(ba, -1)
    # pyre-ignore[16]: scatter_add_
    good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)

    good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()

    # (optional debugging, disabled): occasionally visualize the input and
    # cleaned depth point clouds together with the cameras in Visdom via
    # pytorch3d.vis.plotly_vis.plot_scene

    return good_depth_mask
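The projection/depth bookkeeping at the top of the function is reusable on its own; a minimal sketch (the cloud and camera are illustrative):

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds

cloud = Pointclouds(points=[torch.rand(1000, 3) + torch.tensor([0.0, 0.0, 2.0])])
camera = PerspectiveCameras()

pts = cloud.points_padded()
# xy coordinates of the projected points in pytorch3d NDC space
pcl_proj = camera.transform_points(pts, eps=1e-2)[..., :-1]
# per-point depth in the camera (view) frame, comparable to a depth map
pcl_depth = camera.get_world_to_view_transform().transform_points(pts)[..., -1]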