def integrate_visibility_from_image(h, w, focal, camtoworld, alpha_grid,
                                    visibility_grid, scene_params,
                                    grid_params):
  """Marks the culling-grid voxels that are visible from a given camera.

  Convenience wrapper around integrate_visibility_from_rays. It builds one
  ray per pixel for the camera and converts the rays to NDC when
  scene_params['ndc'] is set. It then passes the rays, flattened to
  [N, 3], on to integrate_visibility_from_rays.

  Args:
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    alpha_grid: A [cW, cH, cD, 1] numpy array for the alpha values in the
      low-res culling grid.
    visibility_grid: A [cW, cH, cD, 1] numpy array for the visibility values
      in the low-res culling grid. Visibility values are added into this
      grid.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution).
    grid_params: A dict with parameters describing the high-res voxel grid
      which the atlas is representing.
  """
  # rays_from_camera is given a leading axis of size 1 (a single camera).
  single_camera = np.expand_dims(camtoworld, 0)
  ray_o, ray_d, _ = datasets.rays_from_camera(
      scene_params['_use_pixel_centers'], h, w, focal, single_camera)
  if scene_params['ndc']:
    ray_o, ray_d = datasets.convert_to_ndc(ray_o, ray_d, focal, w, h)
  flat_origins = ray_o.reshape(-1, 3)
  flat_directions = ray_d.reshape(-1, 3)
  integrate_visibility_from_rays(flat_origins, flat_directions, alpha_grid,
                                 visibility_grid, scene_params, grid_params)
def post_process_render(viewdir_mlp, viewdir_mlp_params, rgb, alpha, h, w,
                        focal, camtoworld, scene_params):
  """Post-processes a SNeRG render (background, view-dependence MLP).

  Composites the render onto the desired background color, then evaluates the
  view-dependence MLP for each pixel, and adds the specular residual.

  The input `rgb` is never modified. Background compositing writes into a
  new array, so callers keep their buffer intact. Read-only arrays, or any
  array-like accepted by np.asarray, can be passed.

  Args:
    viewdir_mlp: A nerf.model_utils.MLP that predicts the per-ray
      view-dependent residual color.
    viewdir_mlp_params: A dict containing the MLP parameters for the
      per-sample view-dependence MLP.
    rgb: A [H, W, C] array. Channels 0:3 hold the accumulated RGB and
      channels 3:scene_params['_channels'] hold the accumulated features;
      any further channels are ignored.
    alpha: A [H, W, 1] array containing the alpha accumulated at each pixel.
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution).

  Returns:
    A list containing post-processed images in the following order: the final
    output image (output_rgb), the alpha channel (alpha), the diffuse-only rgb
    image (rgb), the accumulated feature channels (features), and the specular
    residual from the view-dependence MLP (residual).
  """
  rgb = np.asarray(rgb)
  diffuse = rgb[Ellipsis, 0:3]
  if scene_params['white_bkgd']:
    # Composite onto a white background: rgb + (1 - alpha) * 1. Build a new
    # array instead of assigning into the slice, because a slice is a view
    # and would overwrite the caller's buffer. Cast back to the input dtype,
    # which is what the in-place write would have stored.
    diffuse = np.asarray(
        np.ones_like(diffuse) * (1.0 - alpha) + diffuse, dtype=rgb.dtype)
  features = rgb[Ellipsis, 3:scene_params['_channels']]
  rgb_features = np.concatenate([diffuse, features], -1)

  # One view direction per pixel; unlike the ray-marching wrappers, no NDC
  # conversion is applied here.
  _, _, viewdirs = datasets.rays_from_camera(
      scene_params['_use_pixel_centers'], h, w, focal,
      np.expand_dims(camtoworld, 0))
  viewdirs = viewdirs.reshape((diffuse.shape[0], diffuse.shape[1], 3))

  residual = model_utils.viewdir_fn(viewdir_mlp, viewdir_mlp_params,
                                    rgb_features, viewdirs, scene_params)
  # Add the specular residual and clamp from above at 1.0.
  output_rgb = np.minimum(1.0, diffuse + residual)
  return [output_rgb, alpha, diffuse, features, residual]
def atlas_raymarch_image_tf(h, w, focal, camtoworld, atlas_t,
                            atlas_block_indices_t, atlas_params, scene_params,
                            grid_params):
  """Fast ray marching through a SNeRG scene for a whole image.

  Convenience wrapper around atlas_raymarch_rays_parallel_tf. It generates
  one ray per pixel for the camera (mapped to NDC when scene_params['ndc'] is
  set) and forwards those rays together with the image size.

  Args:
    h: The image height (pixels).
    w: The image width (pixels).
    focal: The image focal length (pixels).
    camtoworld: A numpy array of shape [4, 4] containing the camera-to-world
      transformation matrix for the camera.
    atlas_t: A tensorflow tensor containing the texture atlas.
    atlas_block_indices_t: A tensorflow tensor containing the indirection
      grid.
    atlas_params: A dict with params for building and rendering with the 3D
      texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution).
    grid_params: A dict with parameters describing the high-res voxel grid
      which the atlas is representing.

  Returns:
    Whatever atlas_raymarch_rays_parallel_tf returns, i.e.:
    rgb: A [h, w, C] np.array with the colors and features accumulated for
      each pixel.
    alpha: A [h, w, 1] np.array with the alpha value accumulated for each
      pixel.
  """
  use_pixel_centers = scene_params['_use_pixel_centers']
  pixel_origins, pixel_directions, _ = datasets.rays_from_camera(
      use_pixel_centers, h, w, focal, np.expand_dims(camtoworld, 0))
  if scene_params['ndc']:
    pixel_origins, pixel_directions = datasets.convert_to_ndc(
        pixel_origins, pixel_directions, focal, w, h)
  result = atlas_raymarch_rays_parallel_tf(
      h, w, pixel_origins, pixel_directions, atlas_t, atlas_block_indices_t,
      atlas_params, scene_params, grid_params)
  return result
def _render_view_dependence_inputs(source_dataset, view_index, atlas_t,
                                   atlas_block_indices_t, atlas_params,
                                   scene_params, grid_params):
  """Ray marches one dataset view into (rgb, alpha, viewdirs) numpy arrays.

  Returns arrays of shape [H, W, scene_params['_channels']], [H, W, 1] and
  [H, W, 3]. When scene_params['white_bkgd'] is set, the RGB channels are
  composited onto white.
  """
  h, w, focal = source_dataset.h, source_dataset.w, source_dataset.focal
  camtoworld = source_dataset.camtoworlds[view_index]
  rgb, alpha = rendering.atlas_raymarch_image_tf(
      h, w, focal, camtoworld, atlas_t, atlas_block_indices_t, atlas_params,
      scene_params, grid_params)
  _, _, viewdirs = datasets.rays_from_camera(
      scene_params["_use_pixel_centers"], h, w, focal,
      np.expand_dims(camtoworld, 0))

  # np.array copies, so the in-place compositing below never touches the
  # renderer's output buffer.
  np_rgb = np.array(rgb).reshape((h, w, scene_params["_channels"]))
  np_alpha = np.array(alpha).reshape((h, w, 1))
  np_viewdirs = viewdirs.reshape((h, w, 3))
  if scene_params["white_bkgd"]:
    np_rgb[Ellipsis, 0:3] = np.ones_like(np_rgb[Ellipsis, 0:3]) * (
        1.0 - np_alpha) + np_rgb[Ellipsis, 0:3]
  return np_rgb, np_alpha, np_viewdirs


def build_sharded_dataset_for_view_dependence(source_dataset, atlas_t,
                                              atlas_block_indices_t,
                                              atlas_params, scene_params,
                                              grid_params):
  """Builds a dataset that we can run the view-dependence MLP on.

  We ray march through a baked SNeRG model to generate images with RGB colors
  and features. These serve as the input for the view-dependence MLP which
  adds back the effects such as highlights.

  To make use of multi-host parallelism provided by JAX, this function shards
  the dataset, so that each host contains only a slice of the data. Host
  `host_id` takes views host_id, host_id + NUM_HOSTS, ... The per-host count
  is rounded up to a multiple of NUM_LOCAL_DEVICES. Padding slots past the
  end of the dataset wrap around to an existing view, which is rendered as
  well, so every input is paired with its own reference image.

  Args:
    source_dataset: The nerf.datasets.Dataset we should compute data for.
    atlas_t: A tensorflow tensor containing the texture atlas.
    atlas_block_indices_t: A tensorflow tensor containing the indirection
      grid.
    atlas_params: A dict with params for building and rendering with the 3D
      texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution).
    grid_params: A dict with parameters describing the high-res voxel grid
      which the atlas is representing.

  Returns:
    rgb_data: The RGB (+ features) input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, scene_params['_channels']) numpy
      array.
    alpha_data: The alpha channel of the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 1) numpy array.
    viewdir_data: The direction vectors for the input data, stored as an
      (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.
    ref_data: The reference RGB colors for each input data sample, stored as
      an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array.

  Raises:
    ValueError: If source_dataset contains no cameras.
  """
  num_hosts = jax.host_count()
  num_local_devices = jax.local_device_count()
  host_id = jax.host_id()

  num_images = source_dataset.camtoworlds.shape[0]
  if num_images == 0:
    raise ValueError("source_dataset must contain at least one camera.")

  # Views per host, rounded up so it reshapes cleanly into local devices.
  num_batches = math.ceil(num_images / num_hosts)
  num_batches = num_local_devices * math.ceil(num_batches / num_local_devices)

  rgb_list, alpha_list, viewdir_list, ref_list = [], [], [], []
  for i in range(num_batches):
    # Wrap padding slots onto real views so that the rendered input and the
    # reference image always come from the same camera.
    view_index = (i * num_hosts + host_id) % num_images
    np_rgb, np_alpha, np_viewdirs = _render_view_dependence_inputs(
        source_dataset, view_index, atlas_t, atlas_block_indices_t,
        atlas_params, scene_params, grid_params)
    rgb_list.append(np_rgb)
    alpha_list.append(np_alpha)
    viewdir_list.append(np_viewdirs)
    ref_list.append(source_dataset.images[view_index])

  h, w = source_dataset.h, source_dataset.w
  rgb_data = np.stack(rgb_list, 0).reshape(
      (-1, num_local_devices, h, w, scene_params["_channels"]))
  alpha_data = np.stack(alpha_list, 0).reshape(
      (-1, num_local_devices, h, w, 1))
  viewdir_data = np.stack(viewdir_list, 0).reshape(
      (-1, num_local_devices, h, w, 3))
  ref_data = np.stack(ref_list, 0).reshape(
      (-1, num_local_devices, h, w, 3))
  return rgb_data, alpha_data, viewdir_data, ref_data