def loss_fn(variables):
  """Mean-squared error of the view-dependent color prediction.

  The view-dependence model is evaluated with ``variables`` as its
  parameters. Its output is added to the diffuse RGB (first three channels of
  ``rgb_features``) and clamped above at 1.0. The prediction is then compared
  against the first three channels of ``ref``.

  Args:
    variables: Parameters handed to ``model_utils.viewdir_fn`` for ``model``
      (presumably the parameters being optimized -- confirm against caller).

  Returns:
    The scalar mean of the squared per-channel error.
  """
  specular = model_utils.viewdir_fn(
      model, variables, rgb_features, directions, scene_params)
  diffuse = rgb_features[Ellipsis, 0:3]
  # Only an upper clamp is applied; the lower end is left unclamped.
  predicted = jnp.minimum(1.0, diffuse + specular)
  squared_error = (predicted - ref[Ellipsis, :3]) ** 2
  return squared_error.mean()
def pmap_eval_fn(rgb_and_feature_chunk, direction_chunk):
  """Per-device evaluation of the view-dependence MLP (pmap body).

  Written as an inner function because a pmap'd callable may only receive
  JAX types as arguments. The closed-over model, parameters and scene
  parameters are captured from the enclosing scope instead.

  Args:
    rgb_and_feature_chunk: Per-device chunk whose first three channels are
      the diffuse RGB, followed by feature channels.
    direction_chunk: Per-device chunk of view directions matching the above.

  Returns:
    The clamped (<= 1.0) diffuse + residual colors, all-gathered across the
    "batch" mapped axis.
  """
  specular_residual = model_utils.viewdir_fn(
      viewdir_mlp_model,
      viewdir_mlp_params,
      rgb_and_feature_chunk,
      direction_chunk,
      scene_params,
  )
  diffuse_rgb = rgb_and_feature_chunk[Ellipsis, 0:3]
  clamped_rgb = jnp.minimum(1.0, diffuse_rgb + specular_residual)
  # Collect every device's output so each replica holds the full result.
  return jax.lax.all_gather(clamped_rgb, axis_name="batch")
def post_process_render(viewdir_mlp, viewdir_mlp_params, rgb, alpha, h, w,
                        focal, camtoworld, scene_params):
  """Post-processes a SNeRG render (background, view-dependence MLP).

  Optionally composites the render onto a white background, then evaluates
  the view-dependence MLP for each pixel and adds the specular residual.
  The input ``rgb`` array is never modified; compositing produces a new array.

  Args:
    viewdir_mlp: A nerf.model_utils.MLP that predicts the per-ray
      view-dependent residual color.
    viewdir_mlp_params: A dict containing the MLP parameters for the
      per-sample view-dependence MLP.
    rgb: A [H, W, C] array (e.g. C = 7) holding the RGB (channels 0:3) and
      features (channels 3:scene_params['_channels']) accumulated at each
      pixel.
    alpha: A [H, W, 1] array containing the alpha accumulated at each pixel.
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution). Keys read here: 'white_bkgd', '_channels',
      '_use_pixel_centers'.

  Returns:
    A list containing post-processed images in the following order:
    the final output image (output_rgb), the alpha channel (alpha), the
    diffuse-only rgb image (rgb), the accumulated feature channels (features),
    and the specular residual from the view-dependence MLP (residual).
  """
  diffuse = rgb[Ellipsis, 0:3]
  if scene_params['white_bkgd']:
    # Composite onto white into a NEW array. Writing back into ``rgb`` would
    # silently alter the caller's buffer (slices are views) and fails for
    # read-only / immutable inputs. Cast back to the input dtype to match the
    # implicit cast that the previous in-place assignment performed.
    composited = np.ones_like(diffuse) * (1.0 - alpha) + diffuse
    diffuse = composited.astype(diffuse.dtype)

  features = rgb[Ellipsis, 3:scene_params['_channels']]
  rgb_features = np.concatenate([diffuse, features], -1)

  _, _, viewdirs = datasets.rays_from_camera(
      scene_params['_use_pixel_centers'], h, w, focal,
      np.expand_dims(camtoworld, 0))
  viewdirs = viewdirs.reshape((diffuse.shape[0], diffuse.shape[1], 3))

  residual = model_utils.viewdir_fn(viewdir_mlp, viewdir_mlp_params,
                                    rgb_features, viewdirs, scene_params)
  # Only an upper clamp is applied to the final color.
  output_rgb = np.minimum(1.0, diffuse + residual)
  return [output_rgb, alpha, diffuse, features, residual]