Example 1
  def create_cameras_dataset(
      self,
      cameras: Union[Iterable[tfcam.TFCamera], Iterable[gpath.GPath]],
      flatten=False,
      shuffle=False):
    """Creates a tf.data.Dataset from a list of cameras."""
    # Indexing requires a sequence; materialize the iterable first.
    cameras = list(cameras)
    if isinstance(cameras[0], (gpath.GPath, str)):
      cameras = utils.parallel_map(self.load_camera, cameras)

    def _generator():
      for camera in cameras:
        yield {'camera_params': camera.get_parameters()}

    dataset = tf.data.Dataset.from_generator(
        _generator,
        output_signature={'camera_params': _TF_CAMERA_PARAMS_SIGNATURE})
    dataset = dataset.map(
        functools.partial(_camera_to_rays_fn, use_tf_camera=True), _TF_AUTOTUNE)

    if flatten:
      # Unbatch images to rows.
      dataset = dataset.unbatch()
      if shuffle:
        dataset = dataset.shuffle(20000)
      # Unbatch rows to rays.
      dataset = dataset.unbatch()
      if shuffle:
        dataset = dataset.shuffle(20000)

    return dataset
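
The pattern worth noting here is the pair of `unbatch` calls: the first splits each image into rows, the second splits rows into individual rays, with an optional shuffle at each level. Below is a self-contained, hypothetical sketch of the same pattern with synthetic data (the shapes, `_generator`, and `_to_rays` stand-ins are assumptions, not the real `_camera_to_rays_fn`):

import tensorflow as tf

H, W = 4, 6  # toy image size

def _generator():
  # Stand-in for per-camera parameter dicts.
  for i in range(3):
    yield {'camera_params': tf.fill([7], float(i))}

def _to_rays(item):
  # Stand-in for _camera_to_rays_fn: expand params into per-pixel ray origins.
  params = item['camera_params']
  return {'origins': tf.broadcast_to(params[:3], [H, W, 3])}

dataset = tf.data.Dataset.from_generator(
    _generator,
    output_signature={'camera_params': tf.TensorSpec([7], tf.float32)})
dataset = dataset.map(_to_rays, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.unbatch()  # images -> rows of shape [W, 3]
dataset = dataset.unbatch()  # rows -> individual rays of shape [3]
print(next(iter(dataset))['origins'].shape)  # (3,)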
Example 2
  def _create_preloaded_dataset(self, item_ids, flatten=False, shuffle=False):
    """Crates a dataset completely preloaded in memory.

    This creates a tf.data.Dataset which is constructed by load all data
    into memory and pre-shuffling (if applicable). This is much faster than
    having tf.data.Dataset handle individual items.

    Args:
      item_ids: the item IDs to construct the dataset with.
      flatten: whether to flatten the image dimensions.
      shuffle: whether to shuffle the dataset.

    Returns:
      A tf.data.Dataset instance.
    """
    load_fn = self.get_item
    data_list = utils.parallel_map(load_fn, item_ids)
    data_list = [_camera_to_rays_fn(item) for item in data_list]
    data_dict = utils.tree_collate(data_list)

    num_examples = data_dict['origins'].shape[0]
    heights = [x.shape[0] for x in data_dict['origins']]
    widths = [x.shape[1] for x in data_dict['origins']]

    # Broadcast appearance ID to match ray shapes.
    if 'metadata' in data_dict:
      for metadata_key, metadata in data_dict['metadata'].items():
        data_dict['metadata'][metadata_key] = np.asarray([
            np.full((heights[i], widths[i], 1), fill_value=x)
            for i, x in enumerate(metadata)
        ])

    num_rays = int(sum([x * y for x, y in zip(heights, widths)]))
    shuffled_inds = self.rng.permutation(num_rays)

    logging.info('*** Loaded dataset items: num_rays=%d, num_examples=%d',
                 num_rays, num_examples)

    def _prepare_array(x):
      if not isinstance(x, np.ndarray):
        x = np.asarray(x)
      # Create last dimension if it doesn't exist.
      # The `and` part of the check ensures we're not touching ragged arrays.
      if x.ndim == 1 and x[0].ndim == 0:
        x = np.expand_dims(x, -1)
      if flatten:
        # Flatten each image's (height, width) dims into a single ray axis.
        x = np.concatenate(
            [img.reshape(-1, img.shape[-1]) for img in x], axis=0)
      if shuffle:
        x = x[shuffled_inds]
      return x

    out_dict = {}
    for key, value in data_dict.items():
      out_dict[key] = jax.tree_util.tree_map(_prepare_array, value)

    return tf.data.Dataset.from_tensor_slices(out_dict)
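
The subtle part above is `_prepare_array`: differently sized images arrive as a ragged object array, each image is flattened to `(num_pixels, channels)`, and everything is concatenated and permuted with one shared index order so all keys stay aligned. A minimal NumPy sketch with toy shapes:

import numpy as np

rng = np.random.default_rng(0)

# Two 'images' of different sizes with 3 channels each (a ragged batch).
origins = np.empty(2, dtype=object)
origins[0] = np.zeros((2, 3, 3))  # 2x3 image -> 6 rays
origins[1] = np.ones((4, 5, 3))   # 4x5 image -> 20 rays

num_rays = sum(img.shape[0] * img.shape[1] for img in origins)
shuffled_inds = rng.permutation(num_rays)

# Flatten each image to (num_pixels, 3), concatenate, then permute once.
flat = np.concatenate(
    [img.reshape(-1, img.shape[-1]) for img in origins], axis=0)
print(flat[shuffled_inds].shape)  # (26, 3)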
Example 3
  def load_test_cameras(self, count=None):
    """Loads test cameras, optionally subsampled to roughly `count`."""
    camera_dir = self.data_dir / 'camera-paths' / self.test_camera_trajectory
    if not camera_dir.exists():
      logging.warning('test camera path does not exist: %s', str(camera_dir))
      return []
    camera_paths = sorted(camera_dir.glob(f'*{self.camera_ext}'))
    if count is not None:
      # Subsample with an even stride so cameras span the whole trajectory.
      stride = max(1, len(camera_paths) // count)
      camera_paths = camera_paths[::stride]
    cameras = utils.parallel_map(self.load_camera, camera_paths)
    return cameras
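
Note that the stride subsampling keeps the selected cameras spread evenly along the trajectory rather than taking the first `count`, and it can return slightly more than `count` items. A quick illustration with hypothetical filenames:

paths = [f'cam_{i:03d}.json' for i in range(10)]  # hypothetical names
count = 3
stride = max(1, len(paths) // count)  # 10 // 3 == 3
print(paths[::stride])
# ['cam_000.json', 'cam_003.json', 'cam_006.json', 'cam_009.json']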
Example 4
  def create_cameras_dataset(self,
                             cameras: Union[Iterable[cam.Camera],
                                            Iterable[gpath.GPath]],
                             flatten=False,
                             shuffle=False):
    """Creates a tf.data.Dataset of rays from cameras or camera paths."""
    # Indexing requires a sequence; materialize the iterable first.
    cameras = list(cameras)
    if isinstance(cameras[0], (gpath.GPath, str)):
      cameras = utils.parallel_map(self.load_camera, cameras)
    data_dict = utils.tree_collate([camera_to_rays(c) for c in cameras])
    return dataset_from_dict(
        data_dict, rng=self.rng, flatten=flatten, shuffle=shuffle)
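
`dataset_from_dict` is not shown in this snippet. Purely as an assumption, a minimal sketch consistent with the flatten/shuffle behavior of Example 2 might look like this (not the actual implementation):

import numpy as np
import tensorflow as tf

def dataset_from_dict(data_dict, rng, flatten=False, shuffle=False):
  # Assumed behavior, mirroring Example 2: flatten image dims into a single
  # ray axis, then optionally permute all rays with one shared ordering.
  def _prepare(x):
    x = np.asarray(x)
    if flatten and x.ndim > 2:
      x = x.reshape(-1, x.shape[-1])
    return x

  out = {key: _prepare(value) for key, value in data_dict.items()}
  if shuffle:
    num = next(iter(out.values())).shape[0]
    perm = rng.permutation(num)
    out = {key: value[perm] for key, value in out.items()}
  return tf.data.Dataset.from_tensor_slices(out)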
Example 5
  def parallel_get_items(self, item_ids, scale_factor=1.0):
    """Loads the data dictionaries for the given item IDs in parallel."""
    load_fn = functools.partial(self.get_item, scale_factor=scale_factor)
    data_list = utils.parallel_map(load_fn, item_ids)
    data_dict = utils.tree_collate(data_list)
    return data_dict
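
`utils.parallel_map` and `utils.tree_collate` are project helpers that are also not shown. A rough sketch of plausible implementations, assuming a thread-pool map and a simple list-of-dicts collate (the real versions presumably handle nested structures):

import concurrent.futures
import numpy as np

def parallel_map(fn, items, max_workers=16):
  # Apply fn to every item concurrently; results keep input order.
  with concurrent.futures.ThreadPoolExecutor(max_workers) as pool:
    return list(pool.map(fn, items))

def tree_collate(list_of_dicts):
  # Collate a list of dicts with identical keys into a dict of arrays,
  # stacking the per-item leaves along a new leading axis.
  keys = list_of_dicts[0].keys()
  return {k: np.asarray([d[k] for d in list_of_dicts]) for k in keys}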