def pixels_to_world(self,
                        pixels,
                        cameras,
                        c,
                        it=None,
                        sampling_accuracy=None):
        ''' Projects pixels to the world coordinate system.

        Args:
            pixels (tensor): sampled pixels in range [-1, 1]
            cameras: camera object passed to image_points_to_world and
                origin_to_world; also provides znear
            c (tensor): latent conditioned code c
            it (int): training iteration (used for ray sampling scheduler)
            sampling_accuracy (tuple): if not None, this overrides the default
                sampling accuracy ([128, 129])
        '''
        batch_size, n_points, _ = pixels.shape
        pixels_world = image_points_to_world(pixels, cameras)
        camera_world = origin_to_world(n_points, cameras)
        ray_vector = (pixels_world - camera_world)

        d_hat, mask_pred, mask_zero_occupied = self.march_along_ray(
            camera_world, ray_vector, c, it, sampling_accuracy)
        # NOTE: d_hat can be zero because call_depth_function sets mask_zero_occupied
        # points to 0, but the camera transform would output nan/inf for zero-depth
        # points, so clamp them to the near plane.
        d_hat[mask_zero_occupied] = cameras.znear
        p_world_hat = camera_world + ray_vector * d_hat.unsqueeze(-1)
        return p_world_hat, mask_pred, mask_zero_occupied
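
The world point returned above is simply the camera origin moved along the (unnormalized) pixel ray by the predicted depth. A tiny self-contained illustration, with toy tensors standing in for the outputs of march_along_ray:

import torch

# Toy stand-ins for the quantities in pixels_to_world (batch of 1, two rays).
camera_world = torch.zeros(1, 2, 3)                          # ray origins o
pixels_world = torch.tensor([[[0., 0., 1.], [1., 0., 1.]]])  # unprojected pixels
ray_vector = pixels_world - camera_world                     # unnormalized directions
d_hat = torch.tensor([[0.5, 2.0]])                           # predicted depth per ray

# Same reconstruction as in the method: p = o + d_hat * (pixel_world - o).
p_world_hat = camera_world + ray_vector * d_hat.unsqueeze(-1)
print(p_world_hat)  # [[[0.0, 0.0, 0.5], [2.0, 0.0, 2.0]]]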
Example #2
    def pixels_to_world(self,
                        pixels,
                        camera_mat,
                        world_mat,
                        scale_mat,
                        c,
                        it=None,
                        sampling_accuracy=None):
        ''' Projects pixels to the world coordinate system.

        Args:
            pixels (tensor): sampled pixels in range [-1, 1]
            camera_mat (tensor): camera matrices
            world_mat (tensor): world matrices
            scale_mat (tensor): scale matrices
            c (tensor): latent conditioned code c
            it (int): training iteration (used for ray sampling scheduler)
            sampling_accuracy (tuple): if not None, this overrides the default
                sampling accuracy ([128, 129])
        '''
        batch_size, n_points, _ = pixels.shape
        pixels_world = image_points_to_world(pixels, camera_mat, world_mat,
                                             scale_mat)
        camera_world = origin_to_world(n_points, camera_mat, world_mat,
                                       scale_mat)
        ray_vector = (pixels_world - camera_world)

        d_hat, mask_pred, mask_zero_occupied = self.march_along_ray(
            camera_world, ray_vector, c, it, sampling_accuracy)
        p_world_hat = camera_world + ray_vector * d_hat.unsqueeze(-1)
        return p_world_hat, mask_pred, mask_zero_occupied
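
In this variant the camera is given as explicit camera_mat / world_mat / scale_mat matrices. The helpers image_points_to_world and origin_to_world are defined elsewhere; below is a minimal sketch, assuming the common DVR-style convention (pixels in [-1, 1] placed at depth 1, lifted to homogeneous coordinates, and mapped through the inverted matrices; the camera origin is (0, 0, 0, 1) mapped the same way). The function names and composition order here are illustrative assumptions, not the project's actual implementation.

import torch

def unproject_pixels_sketch(pixels, camera_mat, world_mat, scale_mat):
    # pixels: (B, N, 2) in [-1, 1]; all matrices: (B, 4, 4).
    B, N, _ = pixels.shape
    depth = torch.ones(B, N, 1)
    # Homogeneous camera-space points (x * d, y * d, d, 1) with d = 1.
    p_hom = torch.cat([pixels * depth, depth, torch.ones(B, N, 1)], dim=-1)
    inv = torch.inverse(scale_mat) @ torch.inverse(world_mat) @ torch.inverse(camera_mat)
    p_world = (inv @ p_hom.transpose(1, 2)).transpose(1, 2)
    return p_world[..., :3]

def camera_origin_sketch(n_points, camera_mat, world_mat, scale_mat):
    B = camera_mat.shape[0]
    origin = torch.zeros(B, n_points, 4)
    origin[..., 3] = 1.0  # homogeneous camera origin
    inv = torch.inverse(scale_mat) @ torch.inverse(world_mat) @ torch.inverse(camera_mat)
    o_world = (inv @ origin.transpose(1, 2)).transpose(1, 2)
    return o_world[..., :3]

# With identity matrices both sketches simply report camera-space coordinates:
eye = torch.eye(4).unsqueeze(0)
px = torch.tensor([[[0.0, 0.0], [0.5, -0.5]]])
print(unproject_pixels_sketch(px, eye, eye, eye))  # depth-1 points in front of the camera
print(camera_origin_sketch(2, eye, eye, eye))      # zeros: the camera origin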
Example #3
    def pixels_to_world(self,
                        pixels,
                        camera_mat,
                        world_mat,
                        scale_mat,
                        c,
                        return_deform=False):
        ''' Projects pixels to the world coordinate system by intersecting
        pixel rays with the mesh and decoding per-hit features.

        Args:
            pixels (tensor): sampled pixels in range [-1, 1]
            camera_mat (tensor): camera matrices
            world_mat (tensor): world matrices
            scale_mat (tensor): scale matrices
            c (tensor): latent conditioned code c
            return_deform (bool): if True, additionally return the barycentric
                coordinates, predicted offsets, and vertex features
        '''
        device = self._device
        batch_size, n_p, _ = pixels.shape

        pixels_world = image_points_to_world(pixels, camera_mat, world_mat,
                                             scale_mat)
        camera_world = origin_to_world(n_p, camera_mat, world_mat, scale_mat)
        ray_vector = (pixels_world - camera_world)

        location, index_ray, index_tri = self.mesh.ray.intersects_location(
            ray_origins=camera_world.detach().cpu().numpy().reshape(
                batch_size * n_p, 3),
            ray_directions=ray_vector.detach().cpu().numpy().reshape(
                batch_size * n_p, 3),
            multiple_hits=False)
        # Boolean hit mask: True for rays that intersect the mesh.
        mask = np.zeros(batch_size * n_p, dtype=bool)
        mask[index_ray] = True

        tris = self.mesh.faces[index_tri]  # H*3
        vert = self.vert[tris]  # H*3*3
        feat = self.feat[tris]  # H*3*F

        # trimesh returns float64 hit locations; cast to the model's float dtype.
        location = torch.from_numpy(location).float().to(device)
        coords = calculate_berycentric_coords(location, vert)
        # Barycentric interpolation of the vertex features: (H*3) x (H*3*F) -> H*F.
        weighted_feat = (coords[:, :, None] * feat).sum(dim=1)

        coords_encoded = self.position_encoding(coords, self.B)
        out = self.decoder(torch.cat([coords_encoded, weighted_feat], dim=-1),
                           c=c)

        p_world = camera_world.clone().view(batch_size * n_p, 3)
        c_world = torch.zeros_like(p_world)
        p_world[mask] = location + out[:, 0:3]
        c_world[mask] = torch.sigmoid(out[:, 3:6])

        p_world = p_world.view(batch_size, n_p, 3)
        c_world = c_world.view(batch_size, n_p, 3)
        mask = mask.reshape(batch_size, n_p)

        if return_deform:
            return p_world, c_world, mask, coords, out[:, 0:3], feat
        else:
            return p_world, c_world, mask
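
calculate_berycentric_coords, position_encoding, and self.decoder are project helpers not shown here. As a hedged, self-contained illustration of the barycentric-coordinate step (the standard dot-product formulation over the triangle edge vectors; the project's helper may differ), followed by the coordinate-weighted feature interpolation used above:

import torch

def barycentric_coords_sketch(p, tri):
    # p: (H, 3) query points on the triangles, tri: (H, 3, 3) triangle vertices.
    # Illustrative only; calculate_berycentric_coords may be implemented differently.
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    v0, v1, v2 = b - a, c - a, p - a
    d00 = (v0 * v0).sum(-1)
    d01 = (v0 * v1).sum(-1)
    d11 = (v1 * v1).sum(-1)
    d20 = (v2 * v0).sum(-1)
    d21 = (v2 * v1).sum(-1)
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w
    return torch.stack([u, v, w], dim=-1)   # (H, 3), sums to 1 for in-plane points

tri = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
p = torch.tensor([[0.25, 0.25, 0.0]])
coords = barycentric_coords_sketch(p, tri)          # tensor([[0.50, 0.25, 0.25]])
feat = torch.arange(9.0).view(1, 3, 3)              # toy per-vertex features, H*3*F
weighted_feat = (coords[:, :, None] * feat).sum(1)  # interpolated feature, H*F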