Example #1
    def test_invalid_inputs_shapes(self, device="cuda:0"):
        with self.assertRaisesRegex(
            ValueError, "input can only be 2-dimensional."
        ):
            values = torch.rand((100, 50, 2), device=device)
            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
            packed_to_padded(values, first_idxs, 100)

        with self.assertRaisesRegex(
            ValueError, "input can only be 3-dimensional."
        ):
            values = torch.rand((100,), device=device)
            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
            padded_to_packed(values, first_idxs, 20)

        with self.assertRaisesRegex(
            ValueError, "input can only be 3-dimensional."
        ):
            values = torch.rand((100, 50, 2, 2), device=device)
            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
            padded_to_packed(values, first_idxs, 20)
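
For contrast with the failure cases above, a minimal valid-shape sketch (assuming pytorch3d.ops.packed_to_padded / padded_to_packed and a CUDA device; the sizes are illustrative):

import torch
from pytorch3d.ops import packed_to_padded, padded_to_packed

device = "cuda:0"
# packed form: (F, D) values for all batch elements concatenated along dim 0
values_packed = torch.rand((100, 2), device=device)
# first_idxs[i] = row in the packed tensor where batch element i starts
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
# padded form: (N, max_size, D); here max_size=80 matches the largest batch element
values_padded = packed_to_padded(values_packed, first_idxs, 80)   # (2, 80, 2)
values_back = padded_to_packed(values_padded, first_idxs, 100)    # (100, 2)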
Example #2
    def _test_packed_to_padded_helper(self, D, device):
        """
        Check the results from packed_to_padded and PyTorch implementations
        are the same.
        """
        meshes = self.init_meshes(16, 100, 300, device=device)
        faces = meshes.faces_packed()
        mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
        max_faces = meshes.num_faces_per_mesh().max().item()

        if D == 0:
            values = torch.rand((faces.shape[0], ),
                                device=device,
                                requires_grad=True)
        else:
            values = torch.rand((faces.shape[0], D),
                                device=device,
                                requires_grad=True)
        values_torch = values.detach().clone()
        values_torch.requires_grad = True
        values_padded = packed_to_padded(values,
                                         mesh_to_faces_packed_first_idx,
                                         max_faces)
        values_padded_torch = TestPackedToPadded.packed_to_padded_python(
            values_torch, mesh_to_faces_packed_first_idx, max_faces, device)
        # check forward
        self.assertClose(values_padded, values_padded_torch)

        # check backward
        if D == 0:
            grad_inputs = torch.rand((len(meshes), max_faces), device=device)
        else:
            grad_inputs = torch.rand((len(meshes), max_faces, D),
                                     device=device)
        values_padded.backward(grad_inputs)
        grad_outputs = values.grad
        values_padded_torch.backward(grad_inputs)
        grad_outputs_torch1 = values_torch.grad
        grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python(
            grad_inputs,
            mesh_to_faces_packed_first_idx,
            values.size(0),
            device=device,
        )
        self.assertClose(grad_outputs, grad_outputs_torch1)
        self.assertClose(grad_outputs, grad_outputs_torch2)
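
The TestPackedToPadded.packed_to_padded_python reference used above is not shown in this snippet; a plausible naive version (an assumption about its behavior, not the test's actual code) simply copies each mesh's slice of the packed tensor into a zero-initialized padded tensor:

import torch

def packed_to_padded_python(inputs, first_idxs, max_size, device):
    # inputs: packed (F,) or (F, D); first_idxs[i]: start row of mesh i in the packed tensor
    num_meshes = first_idxs.shape[0]
    feature_dims = tuple(inputs.shape[1:])
    inputs_padded = torch.zeros((num_meshes, max_size) + feature_dims, device=device)
    for i in range(num_meshes):
        start = first_idxs[i]
        end = inputs.shape[0] if i + 1 == num_meshes else first_idxs[i + 1]
        inputs_padded[i, : end - start] = inputs[start:end]
    return inputs_padded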
Example #3
import torch


def reduce_mask_padded(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Remove as many invalid (masked-out) entries as possible from a padded tensor.
    Args:
        values (tensor(number)): (N, P, ..., C)
        mask   (tensor(bool)): (N, P, ...) bool values
    Returns:
        reduced (tensor(number)): (N, Pmax, ..., C) Pmax is the maximum number of
            True values in the batches of mask.
    """
    from pytorch3d.ops import packed_to_padded
    batch_size = values.shape[0]
    value_packed = values[mask]
    first_idx = torch.zeros(
        (values.shape[0],), device=values.device, dtype=torch.long)
    num_true_in_batch = mask.view(batch_size, -1).sum(dim=1)
    first_idx[1:] = num_true_in_batch.cumsum(dim=0)[:-1]
    value_padded = packed_to_padded(
        value_packed, first_idx, num_true_in_batch.max().item())
    return value_padded
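
A hedged usage sketch (toy shapes chosen only for illustration, assuming pytorch3d is installed): given padded per-point features and a validity mask, the helper repacks the valid entries and re-pads each batch element to the largest per-batch count, zero-filling the rest.

import torch

values = torch.rand(2, 5, 3)                  # (N=2, P=5, C=3) padded features
mask = torch.tensor([[True, False, True, True, False],
                     [True, True, False, False, False]])
reduced = reduce_mask_padded(values, mask)    # (2, 3, 3); row 1 is zero-padded after its 2 valid points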
Example #4
def out():
    packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces)
    torch.cuda.synchronize()
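
The torch.cuda.synchronize() in this closure is what makes GPU timing meaningful, since CUDA kernels launch asynchronously. A hedged sketch of timing such a callable with torch.utils.benchmark (assuming values, mesh_to_faces_packed_first_idx and max_faces are already set up as in Example #2):

import torch.utils.benchmark as benchmark

# time the closure defined above; timeit() returns a Measurement with timing statistics
timer = benchmark.Timer(stmt="out()", globals={"out": out})
print(timer.timeit(100))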
Example #5
File: rasterizer.py Project: yifita/DSS
    def backward(ctx, idx_grad, zbuf_grad, qvalue_grad, occ_grad):
        # idx_grad and zbuf_grad are None (unless maybe we make weights depend on z? i.e. volumetric splatting)

        grad_radii = None
        grad_cloud_to_packed_first_idx = None
        grad_num_points_per_cloud = None
        grad_cutoff_thres = None
        grad_depth_merging_thres = None
        grad_image_size = None
        grad_points_per_pixel = None
        grad_bin_size = None
        grad_max_points_per_bin = None
        grad_radii_s = None
        grad_backward_rbf = None

        grads = (grad_cutoff_thres, grad_radii, grad_cloud_to_packed_first_idx,
                 grad_num_points_per_cloud, grad_depth_merging_thres,
                 grad_image_size, grad_points_per_pixel, grad_bin_size,
                 grad_max_points_per_bin, grad_radii_s, grad_backward_rbf)

        radii_s = ctx.radii_backward_scaler

        # either use OccRBFBackward or use OccBackward
        (pts_screen, ellipse_param, cutoff_threshold, radii, idx, zbuf0,
         cloud_to_packed_first_idx, num_points_per_cloud) = ctx.saved_tensors
        depth_merging_threshold = ctx.depth_merging_threshold

        backward_occ_fast = True
        if not backward_occ_fast:
            device = pts_screen.device
            grads_input_xy = pts_screen.new_zeros((pts_screen.shape[0], 2))
            grads_input_z = pts_screen.new_zeros((pts_screen.shape[0], 1))
            mask = (idx[..., 0] >= 0).bool()  # float
            pts_visibility = torch.full((pts_screen.shape[0], ),
                                        False,
                                        dtype=torch.bool,
                                        device=pts_screen.device)
            # all rendered points (indices in packed points)
            visible_idx = idx[mask].unique().long().view(-1)
            visible_idx = visible_idx[visible_idx >= 0]
            pts_visibility[visible_idx] = True
            num_points_per_cloud = torch.stack([
                x.sum() for x in torch.split(
                    pts_visibility, num_points_per_cloud.tolist(), dim=0)
            ])
            cloud_to_packed_first_idx = num_points_2_cloud_to_packed_first_idx(
                num_points_per_cloud)

            pts_screen = pts_screen[pts_visibility]
            radii = radii[pts_visibility]
            grad_visible = _C._splat_points_occ_backward(
                pts_screen, radii, occ_grad, cloud_to_packed_first_idx,
                num_points_per_cloud, radii_s, depth_merging_threshold)
            if (torch.isnan(grad_visible).any()
                    or not torch.isfinite(grad_visible).all()):
                print('invalid grad_visible')
            assert (pts_visibility.sum() == grad_visible.shape[0])
            grads_input_xy[pts_visibility] = grad_visible
            _C._backward_zbuf(idx, zbuf_grad, grads_input_z)
            # TODO: is it necessary to concatenate?
            grads_input = torch.cat([grads_input_xy, grads_input_z], dim=-1)
        else:
            """
            We only care about rasterized points (visible points)
            1. Filter [P,*] data to [P_visible,*] data
            2. Fast backward cuda
                2a. call FRNN insertion
                2b. count_sort
            """
            device = pts_screen.device
            mask = (idx[..., 0] >= 0).bool()  # float
            pts_visibility = torch.full((pts_screen.shape[0], ),
                                        False,
                                        dtype=torch.bool,
                                        device=pts_screen.device)
            # all rendered points (indices in packed points)
            visible_idx = idx[mask].unique().long().view(-1)
            visible_idx = visible_idx[visible_idx >= 0]
            pts_visibility[visible_idx] = True
            num_points_per_cloud = torch.stack([
                x.sum() for x in torch.split(
                    pts_visibility, num_points_per_cloud.tolist(), dim=0)
            ])
            cloud_to_packed_first_idx = num_points_2_cloud_to_packed_first_idx(
                num_points_per_cloud)

            pts_screen_visible = pts_screen[pts_visibility]
            radii_visible = radii[pts_visibility]

            #####################################
            #  2a. call FRNN insertion
            #####################################
            N = num_points_per_cloud.shape[0]
            P = pts_screen_visible.shape[0]
            assert (num_points_per_cloud.sum().item() == P)
            # from frnn.frnn import GRID_PARAMS_SIZE, MAX_RES, prefix_sum_cuda
            # imported from
            from prefix_sum import prefix_sum_cuda
            GRID_2D_PARAMS_SIZE = 6
            GRID_2D_MAX_RES = 1024
            GRID_2D_DELTA = 2
            GRID_2D_TOTAL = 5
            RADIUS_CELL_RATIO = 2
            # first convert to padded
            max_P = num_points_per_cloud.max().item()
            pts_padded = ops3d.packed_to_padded(pts_screen_visible,
                                                cloud_to_packed_first_idx,
                                                max_P)
            radii_padded = ops3d.packed_to_padded(radii_visible,
                                                  cloud_to_packed_first_idx,
                                                  max_P)
            # determine search radius as max(radii)*radii_s
            search_radius = torch.tensor([
                radii_padded[i, :num_points_per_cloud[i]].median() * radii_s
                for i in range(N)
            ],
                                         dtype=torch.float,
                                         device=device)
            # create grid from scratch
            # setup grid params
            grid_params_cuda = torch.zeros((N, GRID_2D_PARAMS_SIZE),
                                           dtype=torch.float,
                                           device=pts_padded.device)
            G = -1
            pts_padded_2D = pts_padded[:, :, :2].clone().contiguous()
            for i in range(N):
                # 0-2 grid_min; 3 grid_delta; 4-6 grid_res; 7 grid_total
                grid_min = pts_padded_2D[i, :num_points_per_cloud[i]].min(
                    dim=0)[0]
                grid_max = pts_padded_2D[i, :num_points_per_cloud[i]].max(
                    dim=0)[0]
                grid_params_cuda[i, :GRID_2D_DELTA] = grid_min
                grid_size = grid_max - grid_min
                cell_size = search_radius[i].item() / RADIUS_CELL_RATIO
                if cell_size < grid_size.min() / GRID_2D_MAX_RES:
                    cell_size = grid_size.min() / GRID_2D_MAX_RES
                grid_params_cuda[i, GRID_2D_DELTA] = 1 / cell_size
                grid_params_cuda[i, GRID_2D_DELTA + 1:GRID_2D_TOTAL] = (
                    torch.floor(grid_size / cell_size) + 1)
                grid_params_cuda[i, GRID_2D_TOTAL] = torch.prod(
                    grid_params_cuda[i, GRID_2D_DELTA + 1:GRID_2D_TOTAL])
                if G < grid_params_cuda[i, GRID_2D_TOTAL]:
                    G = int(grid_params_cuda[i, GRID_2D_TOTAL].item())

            # insert points into the grid
            pc_grid_cnt = torch.zeros((N, G), dtype=torch.int, device=device)
            pc_grid_cell = torch.full((N, max_P),
                                      -1,
                                      dtype=torch.int,
                                      device=device)
            pc_grid_idx = torch.full((N, max_P),
                                     -1,
                                     dtype=torch.int,
                                     device=device)
            frnn._C.insert_points_cuda(pts_padded_2D, num_points_per_cloud,
                                       grid_params_cuda, pc_grid_cnt,
                                       pc_grid_cell, pc_grid_idx, G)

            # use prefix_sum from Matt Dean
            grid_params = grid_params_cuda.cpu()
            pc_grid_off = torch.full((N, G), 0, dtype=torch.int, device=device)
            for i in range(N):
                prefix_sum_cuda(pc_grid_cnt[i], grid_params[i, GRID_2D_TOTAL],
                                pc_grid_off[i])

            # sort points according to their grid positions and insertion orders
            # sort based on x, y first. Then we will use points_sorted_idxs to recover the points_sorted with Z
            points_sorted = torch.zeros((N, max_P, 2),
                                        dtype=torch.float,
                                        device=device)
            points_sorted_idxs = torch.full((N, max_P),
                                            -1,
                                            dtype=torch.int,
                                            device=device)
            frnn._C.counting_sort_cuda(
                pts_padded_2D,
                num_points_per_cloud,
                pc_grid_cell,
                pc_grid_idx,
                pc_grid_off,
                points_sorted,  # (N,P,2)
                points_sorted_idxs  # (N,P)
            )
            new_points_sorted = torch.zeros_like(pts_padded)
            for i in range(N):
                points_sorted_idxs_i = points_sorted_idxs[
                    i, :num_points_per_cloud[i]].long().unsqueeze(1).expand(
                        -1, 3)
                new_points_sorted[i, :num_points_per_cloud[i]] = torch.gather(
                    pts_padded[i], 0, points_sorted_idxs_i)
                # print(points_sorted[i, :10])
                # print(new_points_sorted[i, :10])
            # new_points_sorted = torch.gather(pts_padded, 1, points_sorted_idxs.long().unsqueeze(2).expand(-1, -1, 3))

            assert (new_points_sorted is not None and pc_grid_off is not None
                    and points_sorted_idxs is not None
                    and grid_params_cuda is not None)
            # convert sorted_points and sorted_points_idxs to packed (P, )
            points_sorted = ops3d.padded_to_packed(new_points_sorted,
                                                   cloud_to_packed_first_idx,
                                                   P)
            # padded_to_packed only supports torch.float32...
            shifted_points_sorted_idxs = (
                points_sorted_idxs + cloud_to_packed_first_idx.float().unsqueeze(1))
            points_sorted_idxs = ops3d.padded_to_packed(
                shifted_points_sorted_idxs, cloud_to_packed_first_idx, P)
            points_sorted_idxs_2D = points_sorted_idxs.long().unsqueeze(
                1).expand(-1, 2)
            radii_sorted = torch.gather(radii_visible, 0,
                                        points_sorted_idxs_2D)
            pc_grid_off += cloud_to_packed_first_idx.unsqueeze(1)
            grad_sorted = _C._splat_points_occ_fast_cuda_backward(
                points_sorted, radii_sorted, search_radius, occ_grad,
                num_points_per_cloud, cloud_to_packed_first_idx, pc_grid_off,
                grid_params_cuda)
            # grad_sorted_slow = _C._splat_points_occ_backward(points_sorted, radii_sorted,
            #                                             occ_grad, cloud_to_packed_first_idx, num_points_per_cloud,
            #                                             radii_s, depth_merging_threshold)
            # breakpoint()
            # points_sorted_idxs_3D = points_sorted_idxs.long().unsqueeze(1).expand(-1, 3)
            # print(points_sorted_idxs_3D.max(), grad_sorted.shape[0])
            grad_visible = torch.zeros_like(grad_sorted).scatter_(
                0, points_sorted_idxs_2D, grad_sorted)
            # grad_visible_slow = _C._splat_points_occ_backward(pts_screen[pts_visibility], radii[pts_visibility],
            #                                             occ_grad, cloud_to_packed_first_idx, num_points_per_cloud,
            #                                             radii_s, depth_merging_threshold)
            # breakpoint()
            if (torch.isnan(grad_visible).any()
                    or not torch.isfinite(grad_visible).all()):
                print('invalid grad_visible')
            assert (pts_visibility.sum() == grad_visible.shape[0])
            grads_input_xy = pts_screen.new_zeros(pts_screen.shape[0], 2)
            grads_input_z = pts_screen.new_zeros(pts_screen.shape[0], 1)
            # print("1")
            grads_input_xy[pts_visibility] = grad_visible
            _C._backward_zbuf(idx, zbuf_grad, grads_input_z)
            grads_input = torch.cat([grads_input_xy, grads_input_z], dim=-1)
            # print("2")

        pts_grad = grads_input

        return (pts_grad, None) + grads
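
The long tuple of None values mirrors the torch.autograd.Function contract: backward must return exactly one gradient per argument of forward, with None for non-differentiable inputs. A minimal hedged sketch of that pattern (independent of the DSS internals above):

import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pts, scale, image_size):
        ctx.scale = scale
        return pts * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one gradient per forward input: pts gets a real gradient,
        # the non-tensor arguments (scale, image_size) get None
        return grad_out * ctx.scale, None, None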
Example #6
File: rasterizer.py Project: yifita/DSS
    def forward(self,
                point_clouds,
                point_clouds_filter=None,
                **kwargs) -> PointFragments:
        """
        Args:
            point_clouds (Pointclouds3D): a set of point clouds with coordinates.
            per_point_info (dict):
                radii_packed: (N,2) axis-aligned radii in packed form
                ellipse_params_packed: (N,3) ellipse parameters in packed form
        Returns:
            PointFragments: Rasterization outputs as a named tuple.
        """
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        cameras = kwargs.get('cameras', self.cameras)
        max_P = point_clouds.num_points_per_cloud().max().item()
        total_P = point_clouds.num_points_per_cloud().sum().item()

        point_clouds_filtered, mask_filtered = self.filter_renderable(
            point_clouds, point_clouds_filter, **kwargs)

        if point_clouds_filtered.isempty():
            return self._empty_fragments(cameras.R.shape[0], **kwargs)

        # compute per-point features for elliptical gaussian weights
        with torch.autograd.no_grad():
            per_point_info = self._get_per_point_info(point_clouds_filtered,
                                                      **kwargs)

        _tmp = point_clouds_filtered.points_padded()
        if _tmp.requires_grad:
            _tmp.register_hook(lambda x: _check_grad(x, 'transform'))
        pcls_screen = self.transform(point_clouds_filtered, **kwargs)

        idx, zbuf, qvalue_map, occ_map = rasterize_elliptical_points(
            pcls_screen,
            per_point_info["ellipse_params"],
            per_point_info['cutoff_threshold'],
            per_point_info["radii"],
            depth_merging_threshold=raster_settings.depth_merging_threshold,
            image_size=raster_settings.image_size,
            points_per_pixel=raster_settings.points_per_pixel,
            bin_size=raster_settings.bin_size,
            max_points_per_bin=raster_settings.max_points_per_bin,
            radii_backward_scaler=raster_settings.radii_backward_scaler,
            clip_pts_grad=raster_settings.clip_pts_grad)

        # compute weight: scalar*exp(-0.5Q)
        frag_scaler = gather_with_neg_idx(per_point_info['scaler'], 0,
                                          idx.view(-1).long())
        frag_scaler = frag_scaler.view_as(qvalue_map)

        fragments = PointFragments(idx=idx,
                                   zbuf=zbuf,
                                   qvalue=qvalue_map,
                                   scaler=frag_scaler,
                                   occupancy=occ_map)

        # returns (P,) boolean mask for visibility
        visibility_mask = get_per_point_visibility_mask(
            point_clouds_filtered, fragments)
        mask_filtered[mask_filtered] = visibility_mask

        if point_clouds_filter is not None:
            # update point_clouds visibility filter
            # we use this information in projection loss
            # put all_depth_visibility_mask (num_active) to original_visibility_mask (P,)
            # transform to padded
            # original_visibility_mask = ops3d.packed_to_padded(
            #     valid_depth_mask.float(), first_idx, max_P).bool()
            # lixin
            original_visibility_mask = ops3d.packed_to_padded(
                mask_filtered.float(),
                point_clouds.cloud_to_packed_first_idx(), max_P).bool()
            point_clouds_filter.set_filter(visibility=original_visibility_mask)

        if kwargs.get('verbose', False):
            # use scatter to get per point info of the original
            original_per_point_info = {}
            for k in per_point_info:
                original_per_point_info[k] = per_point_info[k].new_zeros(
                    (total_P, ) + per_point_info[k].shape[1:])
                original_per_point_info[k][mask_filtered] = per_point_info[k]

            return fragments, point_clouds_filtered, original_per_point_info
        return fragments, point_clouds_filtered
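
As the mask_filtered.float() / .bool() round-trip above hints, packed_to_padded operates on float tensors, so boolean masks are converted to float before padding and back afterwards. A hedged toy sketch of that conversion (shapes are illustrative):

import torch
from pytorch3d.ops import packed_to_padded

device = "cuda:0"
mask_packed = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)  # (P,) packed over 2 clouds
first_idx = torch.tensor([0, 3], dtype=torch.int64, device=device)            # cloud 0 starts at row 0, cloud 1 at row 3
mask_padded = packed_to_padded(mask_packed.float(), first_idx, 3).bool()      # (2, 3) padded boolean mask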
Example #7
    def test_dataset(self):
        # 1. rerender input point clouds / meshes using the saved camera_mat
        #    compare mask image with saved mask image
        # 2. backproject masked points to space with dense depth map,
        #    fuse all views and save
        batch_size = 1
        device = torch.device('cuda:0')

        data_dir = 'data/synthetic/cube_mesh'
        output_dir = os.path.join('tests', 'outputs', 'test_data')
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # dataset
        dataset = MVRDataset(data_dir=data_dir,
                             load_dense_depth=True,
                             mode="train")
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=0,
                                                  shuffle=False)
        meshes = load_objs_as_meshes([os.path.join(data_dir,
                                                   'mesh.obj')]).to(device)
        cams = dataset.get_cameras().to(device)
        image_size = imageio.imread(dataset.image_files[0]).shape[0]

        # initialize rasterizer; we only check mask pngs, so there is no need to create lights, shaders, etc.
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=5,
            # bin_size controls whether naive or coarse-to-fine rasterization is used
            bin_size=None,
            max_faces_per_bin=None  # this setting is for coarse rasterization
        )
        rasterizer = MeshRasterizer(cameras=None,
                                    raster_settings=raster_settings)

        # render with loaded camera positions and training transformation functions
        pixel_world_all = []
        for idx, data in enumerate(data_loader):
            # get data
            img = data.get('img.rgb').to(device)
            assert (img.min() >= 0 and img.max() <= 1
                    ), "Image must be a floating number between 0 and 1."
            mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)

            camera_mat = data['camera_mat'].to(device)

            cams.R, cams.T = decompose_to_R_and_t(camera_mat)
            cams._N = cams.R.shape[0]
            cams.to(device)
            self.assertTrue(
                torch.equal(cams.get_world_to_view_transform().get_matrix(),
                            camera_mat))

            # transform to view and rerender with non-rotated camera
            verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                     cams)
            meshes_in_view = meshes.offset_verts(
                -meshes.verts_packed() + padded_to_packed(
                    verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                    meshes.verts_packed().shape[0]))

            fragments = rasterizer(meshes_in_view,
                                   cameras=dataset.get_cameras().to(device))

            # compare mask
            mask = fragments.pix_to_face[..., :1] >= 0
            imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                            mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
            # allow 5 pixels difference
            self.assertTrue(torch.sum(mask_gt != mask) < 5)

            # check dense maps
            # backproject points to the world pixel range (-1, 1)
            pixels = arange_pixels((image_size, image_size),
                                   batch_size)[1].to(device)

            depth_img = data.get('img.depth').to(device)
            # get the depth and mask at the sampled pixel position
            depth_gt = get_tensor_values(depth_img,
                                         pixels,
                                         squeeze_channel_dim=True)
            mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                        pixels,
                                        squeeze_channel_dim=True).bool()
            # get pixels and depth inside the masked area
            pixels_packed = pixels[mask_gt]
            depth_gt_packed = depth_gt[mask_gt]
            first_idx = torch.zeros((pixels.shape[0], ),
                                    device=device,
                                    dtype=torch.long)
            num_pts_in_mask = mask_gt.sum(dim=1)
            first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
            pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                             num_pts_in_mask.max().item())
            depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                               num_pts_in_mask.max().item())
            # backproject to world coordinates
            # contains nan and infinite values due to depth_gt_padded containing 0.0
            pixel_world_padded = transform_to_world(pixels_padded,
                                                    depth_gt_padded[..., None],
                                                    cams)
            # transform back to list, containing no padded values
            split_size = num_pts_in_mask[..., None].repeat(1, 2)
            split_size[:, 1] = 3
            pixel_world_list = padded_to_list(pixel_world_padded, split_size)
            pixel_world_all.extend(pixel_world_list)

            idx += 1
            if idx >= 10:
                break

        pixel_world_all = torch.cat(pixel_world_all, dim=0)
        mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(),
                               faces=None,
                               process=False)
        mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))