def compute_object_intersect_tensors(name, ray_batch, scene_info, far,
                                     object2padding, swap_object_yz, **kwargs):
    """Compute the subset of rays that intersect the named object's box.

    Args:
        name: Object name used to look up its box and padding.
        ray_batch: [R, M] packed per-ray tensor.
        scene_info: Dict. Scene information.
        far: Scalar far-plane limit for the intersection test.
        object2padding: Dict mapping object name to box padding.
        swap_object_yz: Whether to swap the object's Y and Z axes.
        **kwargs: Must contain 'box_delta_t'.

    Returns:
        intersect_ray_batch: [R', M] rays that hit the box, with bounds updated.
        intersect_indices: [R', 1] indices of the intersecting rays.
    """
    # Unpack the per-ray fields needed for the intersection test.
    fields = {
        key: ray_utils.extract_slice_from_ray_batch(ray_batch=ray_batch, key=key)
        for key in ('origin', 'direction', 'far', 'metadata')
    }  # origin/direction: [R, 3], far: [R, 1], metadata: [R, 1]

    # Per-ray object box parameters, selected by each ray's scene ID.
    box_dims, box_center, box_rotation = (  # [R, 3], [R, 3], [R, 3, 3]
        scene_utils.extract_object_boxes_for_scenes(
            name=name,
            scene_info=scene_info,
            sids=fields['metadata'],  # [R, 1]
            padding=object2padding[name],
            swap_yz=swap_object_yz,
            box_delta_t=kwargs['box_delta_t']))

    # Ray-bbox intersection: bounds/indices for hits, plus a per-ray mask.
    intersect_bounds, intersect_indices, intersect_mask = (  # [R',2],[R',],[R,]
        box_utils.compute_ray_bbox_bounds_pairwise(
            rays_o=fields['origin'],
            rays_d=fields['direction'],
            rays_far=fields['far'],
            box_length=box_dims[:, 0],
            box_width=box_dims[:, 1],
            box_height=box_dims[:, 2],
            box_center=box_center,
            box_rotation=box_rotation,
            far_limit=far))

    # Keep only the intersecting rays, then tighten their near/far bounds
    # to the box entry/exit distances.
    intersect_ray_batch = apply_intersect_mask_to_tensors(
        intersect_mask=intersect_mask, tensors=[ray_batch])[0]  # [R', M]
    intersect_ray_batch = ray_utils.update_ray_batch_bounds(
        ray_batch=intersect_ray_batch, bounds=intersect_bounds)  # [R', M]
    return intersect_ray_batch, intersect_indices  # [R', M], [R', 1]
def network_query_fn_helper_nodirs(pts,
                                   ray_batch,
                                   network,
                                   network_query_fn,
                                   use_viewdirs,
                                   use_lightdirs,
                                   use_lightdir_norm,
                                   scene_info,
                                   use_random_lightdirs,
                                   light_origins=None,
                                   **kwargs):
    """Same as network_query_fn_helper, but without input directions.

    Args:
        pts: [R, S, 3] sample points.
        ray_batch: [R, M] packed per-ray tensor.
        network: The NeRF network to query.
        network_query_fn: Callable that runs the network.
        use_viewdirs: Whether to extract and pass viewing directions.
        use_lightdirs: Unused; kept for interface compatibility. Light
            directions are always computed (the original guard was disabled).
        use_lightdir_norm: Whether to unit-normalize light directions.
        scene_info: Dict. Scene information.
        use_random_lightdirs: Whether to sample random light directions.
        light_origins: Optional light origins forwarded to compute_lightdirs.
        **kwargs: Must contain 'light_pos'.

    Returns:
        light_ray_batch: Ray batch for the light rays.
        raw: Raw network outputs for the sample points.
    """
    # NOTE(review): use_lightdirs is intentionally ignored; see docstring.
    _ = use_lightdirs

    # Extract unit-normalized viewing direction, broadcast over samples.
    if use_viewdirs:
        viewdirs = ray_batch[:, 8:11]  # [R, 3]
        viewdirs = broadcast_samples_dim(x=viewdirs, target=pts)  # [R, S, 3]
    else:
        viewdirs = None

    # Compute the light directions. Requires kwargs['light_pos'].
    light_ray_batch, lightdirs = compute_lightdirs(  # [R, S, 3]
        pts=pts,
        metadata=ray_utils.extract_slice_from_ray_batch(
            ray_batch, key='metadata', use_viewdirs=use_viewdirs),
        scene_info=scene_info,
        use_lightdir_norm=use_lightdir_norm,
        use_random_lightdirs=use_random_lightdirs,
        light_pos=kwargs['light_pos'],
        light_origins=light_origins)

    # Extract additional per-ray metadata.
    rays_data = ray_utils.extract_slice_from_ray_batch(
        ray_batch, key='example_id', use_viewdirs=use_viewdirs)

    # Query NeRF for the corresponding densities for the light points.
    raw = network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
    return light_ray_batch, raw
def create_w2o_transformations_tensors(name, scene_info, ray_batch,
                                       use_viewdirs, box_delta_t):
    """Create transformation tensor from world to object space.

    Returns:
        w2o_rt: [R, 4, 4] per-ray rotation+translation transform.
        w2o_r: [R, 4, 4] per-ray rotation-only transform.
    """
    # Scene ID of each ray, used to index the per-scene transform tables.
    sids = ray_utils.extract_slice_from_ray_batch(
        ray_batch, key='metadata', use_viewdirs=use_viewdirs)  # [R, 1]

    # Per-scene world-to-object transforms, each [N_scenes, 4, 4].
    per_scene_rt, per_scene_r = (
        scene_utils.extract_w2o_transformations_per_scene(
            name=name, scene_info=scene_info, box_delta_t=box_delta_t))

    # Gather each ray's transform from the per-scene tables.
    w2o_rt = tf.gather_nd(params=per_scene_rt, indices=sids)  # [R, 4, 4]
    w2o_r = tf.gather_nd(params=per_scene_r, indices=sids)  # [R, 4, 4]
    return w2o_rt, w2o_r
# Example #4
def compute_view_light_dirs(ray_batch, pts, scene_info, use_viewdirs,
                            use_lightdir_norm, use_random_lightdirs, light_pos):
  """Compute viewing and lighting directions.

  Returns:
    viewdirs: [R, S, 3] viewing directions broadcast over samples.
    light_ray_batch: Ray batch for the light rays.
    lightdirs: [R, S, 3] lighting directions.
  """
  # Viewing directions occupy columns 8:11 of the packed ray batch;
  # broadcast them across the sample dimension of `pts`.
  view_dirs = broadcast_samples_dim(x=ray_batch[:, 8:11], target=pts)

  # Per-ray scene IDs, [R, 1].
  sids = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='metadata', use_viewdirs=use_viewdirs)

  light_ray_batch, light_dirs = compute_lightdirs(  # [R, S, 3]
      pts=pts,
      metadata=sids,
      scene_info=scene_info,
      use_lightdir_norm=use_lightdir_norm,
      use_random_lightdirs=use_random_lightdirs,
      light_pos=light_pos)
  return view_dirs, light_ray_batch, light_dirs
def network_query_fn_helper(pts, ray_batch, network, network_query_fn,
                            viewdirs, lightdirs, use_viewdirs, use_lightdirs):
    """Query the NeRF network.

    Returns:
        raw: Raw network outputs for the sample points.
    """
    # Drop directions the model was not configured to consume.
    viewdirs = viewdirs if use_viewdirs else None
    lightdirs = lightdirs if use_lightdirs else None

    # Extract additional per-ray metadata, [n_rays, metadata_channels].
    rays_data = ray_utils.extract_slice_from_ray_batch(
        ray_batch, key='example_id', use_viewdirs=use_viewdirs)

    # Query NeRF for the corresponding densities for the sample points.
    return network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
# Example #6
def create_shadow_ray_batch(ray_batch, pts, scene_info, light_pos):
  """Create batch for shadow rays.

  A shadow ray is cast from the light position toward each primary sample
  point, so the output has one ray per (primary ray, sample) pair.

  Args:
    ray_batch: [?, M] tf.float32. Primary ray batch.
    pts: [?, S, 3] tf.float32. Primary points.
    scene_info: Dict. Scene information.
    light_pos: Light position spec forwarded to
      `ray_utils.extract_light_positions_for_rays`.

  Returns:
    shadow_ray_batch: [?S, M]
  """
  n_samples = pts.shape[1]  # S

  # Per-ray light positions and scene IDs.
  light_positions = ray_utils.extract_light_positions_for_rays(
      ray_batch=ray_batch, scene_info=scene_info, light_pos=light_pos)
  rays_sid = ray_utils.extract_slice_from_ray_batch(
      ray_batch=ray_batch, key='metadata')

  # Repeat ray-level values once per primary sample, then flatten so they
  # line up with the flattened sample points below.
  light_positions = tf.reshape(
      tf.tile(light_positions[:, None, :], [1, n_samples, 1]),
      [-1, 3])  # [?S, 3]
  rays_sid = tf.reshape(
      tf.tile(rays_sid[:, None, :], [1, n_samples, 1]),
      [-1, 1])  # [?S, 1]
  flat_pts = tf.reshape(pts, [-1, 3])  # [?S, 3]

  # Assemble shadow rays: origin at the light, destination at each point.
  return ray_utils.create_ray_batch(
      rays_o=light_positions, rays_dst=flat_pts, rays_sid=rays_sid)