Example 1
def integrate_visibility_from_image(h, w, focal, camtoworld, alpha_grid,
                                    visibility_grid, scene_params,
                                    grid_params):
    """Marks the voxels which are visible from the a given camera.

  A convenient wrapper function around integrate_visibility_from_rays.

  Args:
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    alpha_grid: A [cW, cH, cD, 1] numpy array for the alpha values in the
       low-res culling grid.
    visibility_grid: A [cW, cH, cD, 1] numpy array for the visibility values in
      the low-res culling grid. Note that this function will be adding
      visibility values into this grid.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    grid_params: A dict with parameters describing the high-res voxel grid which
      the atlas is representing.
  """
    # Generate one ray per pixel for this camera.
    origins, directions, _ = datasets.rays_from_camera(
        scene_params['_use_pixel_centers'], h, w, focal,
        np.expand_dims(camtoworld, 0))
    if scene_params['ndc']:
        # Warp the rays into normalized device coordinates for
        # forward-facing scenes.
        origins, directions = datasets.convert_to_ndc(origins, directions,
                                                      focal, w, h)

    # Flatten the per-pixel rays and accumulate visibility into the grid.
    integrate_visibility_from_rays(origins.reshape(-1, 3),
                                   directions.reshape(-1, 3), alpha_grid,
                                   visibility_grid, scene_params, grid_params)
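
A minimal usage sketch: accumulating visibility over every training camera
before culling. The dataset object below is an assumption patterned on the
nerf.datasets.Dataset used in Example 4; only the call pattern comes from the
source.

# Hypothetical driver loop; dataset, alpha_grid, scene_params and grid_params
# are assumed to already exist.
visibility_grid = np.zeros_like(alpha_grid)  # Same [cW, cH, cD, 1] shape.
for camtoworld in dataset.camtoworlds:
    integrate_visibility_from_image(dataset.h, dataset.w, dataset.focal,
                                    camtoworld, alpha_grid, visibility_grid,
                                    scene_params, grid_params)
# Voxels that no camera ever marked keep a visibility of zero and can be
# culled from the atlas.
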
Example 2
def post_process_render(viewdir_mlp, viewdir_mlp_params, rgb, alpha, h, w,
                        focal, camtoworld, scene_params):
    """Post-processes a SNeRG render (background, view-dependence MLP).

  Composites the render onto the desired background color, then evaluates
  the view-dependence MLP for each pixel, and adds the specular residual.

  Args:
    viewdir_mlp: A nerf.model_utils.MLP that predicts the per-ray view-dependent
      residual color.
    viewdir_mlp_params: A dict containing the MLP parameters for the per-sample
      view-dependence MLP.
    rgb: A [H, W, 7] tensor containing the RGB and features accumulated at each
      pixel.
    alpha: A [H, W, 1] tensor containing the alpha accumulated at each pixel.
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).

  Returns:
    A list containing post-processed images in the following order:
    the final output image (output_rgb), the alpha channel (alpha), the
    diffuse-only rgb image (rgb), the accumulated feature channels (features),
    and the specular residual from the view-dependence MLP (residual).
  """
    if scene_params['white_bkgd']:
        # Composite the alpha-premultiplied color onto a white background.
        rgb[Ellipsis, 0:3] = np.ones_like(
            rgb[Ellipsis, 0:3]) * (1.0 - alpha) + rgb[Ellipsis, 0:3]

    # Split the accumulated buffer into diffuse RGB and the learned features.
    features = rgb[Ellipsis, 3:scene_params['_channels']]
    rgb = rgb[Ellipsis, 0:3]

    rgb_features = np.concatenate([rgb, features], -1)
    _, _, viewdirs = datasets.rays_from_camera(
        scene_params['_use_pixel_centers'], h, w, focal,
        np.expand_dims(camtoworld, 0))
    viewdirs = viewdirs.reshape((rgb.shape[0], rgb.shape[1], 3))

    residual = model_utils.viewdir_fn(viewdir_mlp, viewdir_mlp_params,
                                      rgb_features, viewdirs, scene_params)
    # Add the specular residual and clamp to the valid color range.
    output_rgb = np.minimum(1.0, rgb + residual)

    return_list = [output_rgb, alpha, rgb, features, residual]

    return return_list
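
The in-place update at the top of this function is the standard "over"
composite for alpha-premultiplied color, specialized to a white background. A
minimal sketch of the general form (the helper name is hypothetical, not part
of the source):

def composite_over_background(premultiplied_rgb, alpha, background_color):
    """Standard "over" composite, assuming alpha-premultiplied color.

    With background_color == 1.0 this reduces to the white-background update
    used in post_process_render above.
    """
    return premultiplied_rgb + (1.0 - alpha) * background_color
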
Example 3
def atlas_raymarch_image_tf(h, w, focal, camtoworld, atlas_t,
                            atlas_block_indices_t, atlas_params, scene_params,
                            grid_params):
    """Fast ray marching through a SNeRG scene for an image.

  A convenient wrapper function around atlas_raymarch_rays_parallel_tf.

  Args:
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    atlas_t: A tensorflow tensor containing the texture atlas.
    atlas_block_indices_t: A tensorflow tensor containing the indirection grid.
    atlas_params: A dict with params for building and rendering with
      the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    grid_params: A dict with parameters describing the high-res voxel grid which
      the atlas is representing.

  Returns:
    rgb: A [h, w, C] np.array with the colors and features accumulated for each
      pixel.
    alpha: A [h, w, 1 ] np.array with the alpha value accumuated for each pixel.
  """
    origins, directions, _ = datasets.rays_from_camera(
        scene_params['_use_pixel_centers'], h, w, focal,
        np.expand_dims(camtoworld, 0))
    if scene_params['ndc']:
        origins, directions = datasets.convert_to_ndc(origins, directions,
                                                      focal, w, h)

    return atlas_raymarch_rays_parallel_tf(h, w, origins, directions, atlas_t,
                                           atlas_block_indices_t, atlas_params,
                                           scene_params, grid_params)
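
This wrapper produces exactly the rgb/alpha pair that post_process_render
(Example 2) consumes, so a single frame can plausibly be rendered end to end
as follows (a sketch, assuming the tensors and parameter dicts from the
examples above are in scope):

rgb, alpha = atlas_raymarch_image_tf(h, w, focal, camtoworld, atlas_t,
                                     atlas_block_indices_t, atlas_params,
                                     scene_params, grid_params)
output_rgb, alpha, diffuse_rgb, features, residual = post_process_render(
    viewdir_mlp, viewdir_mlp_params, np.array(rgb), np.array(alpha),
    h, w, focal, camtoworld, scene_params)
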
Example 4
def build_sharded_dataset_for_view_dependence(source_dataset, atlas_t,
                                              atlas_block_indices_t,
                                              atlas_params, scene_params,
                                              grid_params):
    """Builds a dataset that we can run the view-dependence MLP on.

  We ray march through a baked SNeRG model to generate images with RGB colors
  and features. These serve as the input for the view-dependence MLP which adds
  back the effects such as highlights.

  To make use of multi-host parallelism provided by JAX, this function shards
  the dataset, so that each host contains only a slice of the data.

  Args:
    source_dataset: The nerf.datasets.Dataset we should compute data for.
    atlas_t: A tensorflow tensor containing the texture atlas.
    atlas_block_indices_t: A tensorflow tensor containing the indirection grid.
    atlas_params: A dict with params for building and rendering with
      the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    grid_params: A dict with parameters describing the high-res voxel grid which
      the atlas is representing.

  Returns:
    rgb_data: The RGB (+ features) input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 7) numpy array.
    alpha_data: The alpha channel of the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 1) numpy array.
    direction_data: The direction vectors for the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
    ref_data: The reference RGB colors for each input data sample, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
  """

    num_hosts = jax.host_count()
    num_local_devices = jax.local_device_count()
    host_id = jax.host_id()
    num_images = source_dataset.camtoworlds.shape[0]
    # Round the per-host batch count up so every host runs the same number of
    # iterations, then pad it to a multiple of the local device count.
    num_batches = math.ceil(num_images / num_hosts)
    num_batches = num_local_devices * math.ceil(
        num_batches / num_local_devices)

    rgb_list = []
    alpha_list = []
    viewdir_list = []
    ref_list = []
    for i in range(num_batches):
        # Interleave images across hosts: host k renders images
        # k, k + num_hosts, k + 2 * num_hosts, ...
        base_index = i * num_hosts
        dataset_index = base_index + host_id

        # Zero-filled placeholders for padding slots that fall past the end
        # of the dataset.
        rgb = np.zeros(
            (source_dataset.h, source_dataset.w, scene_params["_channels"]),
            dtype=np.float32)
        alpha = np.zeros((source_dataset.h, source_dataset.w, 1),
                         dtype=np.float32)
        viewdirs = np.zeros((source_dataset.h, source_dataset.w, 3),
                            dtype=np.float32)

        # Only render slots that map to a real image; padding slots keep the
        # zero placeholders above.
        if dataset_index < num_images:
            rgb, alpha = rendering.atlas_raymarch_image_tf(
                source_dataset.h, source_dataset.w, source_dataset.focal,
                source_dataset.camtoworlds[dataset_index], atlas_t,
                atlas_block_indices_t, atlas_params, scene_params, grid_params)
            _, _, viewdirs = datasets.rays_from_camera(
                scene_params["_use_pixel_centers"], source_dataset.h,
                source_dataset.w, source_dataset.focal,
                np.expand_dims(source_dataset.camtoworlds[dataset_index], 0))

        np_rgb = np.array(rgb).reshape(
            (source_dataset.h, source_dataset.w, scene_params["_channels"]))
        np_alpha = np.array(alpha).reshape(
            (source_dataset.h, source_dataset.w, 1))
        np_viewdirs = viewdirs.reshape((np_rgb.shape[0], np_rgb.shape[1], 3))
        if scene_params["white_bkgd"]:
            np_rgb[Ellipsis, 0:3] = np.ones_like(np_rgb[Ellipsis, 0:3]) * (
                1.0 - np_alpha) + np_rgb[Ellipsis, 0:3]

        rgb_list.append(np_rgb)
        alpha_list.append(np_alpha)
        viewdir_list.append(np_viewdirs)
        # Padding slots wrap around to a real reference image via the modulo.
        ref_list.append(source_dataset.images[dataset_index % num_images])

    # Stack and reshape to (per-host batches, local devices, H, W, C) so the
    # arrays can be fed directly to per-device computation.
    rgb_data = np.stack(rgb_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w,
         scene_params["_channels"]))
    alpha_data = np.stack(alpha_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 1))
    viewdir_data = np.stack(viewdir_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 3))
    ref_data = np.stack(ref_list, 0).reshape(
        (-1, num_local_devices, source_dataset.h, source_dataset.w, 3))

    return rgb_data, alpha_data, viewdir_data, ref_data
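
To make the padding arithmetic concrete, a worked example with illustrative
numbers (not taken from the source):

import math

num_images = 100        # frames in the source dataset (illustrative)
num_hosts = 2           # jax.host_count()
num_local_devices = 8   # jax.local_device_count()

num_batches = math.ceil(num_images / num_hosts)   # 50 slots per host
num_batches = num_local_devices * math.ceil(
    num_batches / num_local_devices)              # padded up to 56

# Host 0 renders dataset indices 0, 2, ..., 110; the six slots at or past
# num_images stay zero-filled. After stacking, the reshapes yield arrays of
# shape (7, 8, H, W, C), i.e. (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, C).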