Exemple #1
0
 def loss_fn(variables):
     """Return the mean squared error of the corrected RGB against `ref`.

     The prediction is the first three channels of `rgb_features` plus the
     residual returned by `model_utils.viewdir_fn`, capped at 1.0. It is
     compared with the first three channels of `ref`.

     `model`, `rgb_features`, `directions`, `scene_params` and `ref` are free
     variables resolved from the surrounding scope.

     Args:
       variables: Passed to `model_utils.viewdir_fn` as its second argument.

     Returns:
       The mean of the squared difference between prediction and reference.
     """
     correction = model_utils.viewdir_fn(
         model, variables, rgb_features, directions, scene_params)
     base_rgb = rgb_features[Ellipsis, 0:3]
     # Cap at 1.0 before measuring the error.
     predicted = jnp.minimum(1.0, base_rgb + correction)
     squared_error = (predicted - ref[Ellipsis, :3])**2
     return squared_error.mean()
Exemple #2
0
 def pmap_eval_fn(rgb_and_feature_chunk, direction_chunk):
   """Adds the view-direction residual to one chunk and gathers the result.

   This is an inner function because only JAX types can be passed to a pmap;
   `viewdir_mlp_model`, `viewdir_mlp_params` and `scene_params` are free
   variables resolved from the surrounding scope.

   Args:
     rgb_and_feature_chunk: Array passed whole to `model_utils.viewdir_fn`;
       its first three last-axis channels are the colors being corrected.
     direction_chunk: Directions passed to `model_utils.viewdir_fn`.

   Returns:
     The corrected colors, capped at 1.0, all-gathered over the "batch" axis.
   """
   correction = model_utils.viewdir_fn(
       viewdir_mlp_model, viewdir_mlp_params, rgb_and_feature_chunk,
       direction_chunk, scene_params)
   base_rgb = rgb_and_feature_chunk[Ellipsis, 0:3]
   corrected = jnp.minimum(1.0, base_rgb + correction)
   return jax.lax.all_gather(corrected, axis_name="batch")
def post_process_render(viewdir_mlp, viewdir_mlp_params, rgb, alpha, h, w,
                        focal, camtoworld, scene_params):
    """Post-processes a SNeRG render (background, view-dependence MLP).

    Composites the render onto the desired background color, then evaluates
    the view-dependence MLP for each pixel, and adds the specular residual.

    The `rgb` argument is not modified: the white-background composite is
    written to a new array, so calling this function repeatedly on the same
    input gives the same result each time.

    Args:
      viewdir_mlp: A nerf.model_utils.MLP that predicts the per-ray
        view-dependent residual color.
      viewdir_mlp_params: A dict containing the MLP parameters for the
        per-sample view-dependence MLP.
      rgb: A [H, W, 7] tensor containing the RGB and features accumulated at
        each pixel.
      alpha: A [H, W, 1] tensor containing the alpha accumulated at each
        pixel.
      h: The image height (pixels).
      w: The image width (pixels).
      focal: The image focal length (pixels).
      camtoworld: A numpy array of shape [4, 4] containing the
        camera-to-world transformation matrix for the camera.
      scene_params: A dict for scene specific params (bbox, rotation,
        resolution). The keys read here are 'white_bkgd', '_channels' and
        '_use_pixel_centers'; the dict is also forwarded to
        `model_utils.viewdir_fn`.

    Returns:
      A list containing post-processed images in the following order:
      the final output image (output_rgb), the alpha channel (alpha), the
      diffuse-only rgb image (rgb), the accumulated feature channels
      (features), and the specular residual from the view-dependence MLP
      (residual).
    """
    diffuse = rgb[Ellipsis, 0:3]
    if scene_params['white_bkgd']:
        # Composite onto white in a fresh array instead of assigning into the
        # caller's buffer. Casting back to the input dtype reproduces the
        # values an in-place assignment would have stored.
        composited = np.ones_like(diffuse) * (1.0 - alpha) + diffuse
        diffuse = np.asarray(composited, dtype=diffuse.dtype)

    features = rgb[Ellipsis, 3:scene_params['_channels']]
    rgb = diffuse

    rgb_features = np.concatenate([rgb, features], -1)
    # Only the view directions are needed; the first two return values of
    # rays_from_camera are discarded.
    _, _, viewdirs = datasets.rays_from_camera(
        scene_params['_use_pixel_centers'], h, w, focal,
        np.expand_dims(camtoworld, 0))
    # One direction per pixel, laid out like the image.
    viewdirs = viewdirs.reshape((rgb.shape[0], rgb.shape[1], 3))

    residual = model_utils.viewdir_fn(viewdir_mlp, viewdir_mlp_params,
                                      rgb_features, viewdirs, scene_params)
    # Cap the final color at 1.0 after adding the residual.
    output_rgb = np.minimum(1.0, rgb + residual)

    return [output_rgb, alpha, rgb, features, residual]