Example #1
  def _create_preloaded_dataset(self, item_ids, flatten=False, shuffle=False):
    """Crates a dataset completely preloaded in memory.

    This creates a tf.data.Dataset constructed by loading all data into
    memory and pre-shuffling it (if applicable). This is much faster than
    having tf.data.Dataset handle individual items.

    Args:
      item_ids: the item IDs to construct the dataset with.
      flatten: whether to flatten the image dimensions.
      shuffle: whether to shuffle the dataset.

    Returns:
      A tf.data.Dataset instance.
    """
    load_fn = functools.partial(self.get_item)
    data_list = utils.parallel_map(load_fn, item_ids)
    data_list = [_camera_to_rays_fn(item) for item in data_list]
    data_dict = utils.tree_collate(data_list)

    num_examples = data_dict['origins'].shape[0]
    heights = [x.shape[0] for x in data_dict['origins']]
    widths = [x.shape[1] for x in data_dict['origins']]

    # Broadcast appearance ID to match ray shapes.
    if 'metadata' in data_dict:
      for metadata_key, metadata in data_dict['metadata'].items():
        data_dict['metadata'][metadata_key] = np.asarray([
            np.full((heights[i], widths[i], 1), fill_value=x)
            for i, x in enumerate(metadata)
        ])

    num_rays = int(sum([x * y for x, y in zip(heights, widths)]))
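    # One global permutation is shared by every array so that, after
    # shuffling, origins, directions, and metadata stay aligned per ray.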
    shuffled_inds = self.rng.permutation(num_rays)

    logging.info('*** Loaded dataset items: num_rays=%d, num_examples=%d',
                 num_rays, num_examples)

    def _prepare_array(x):
      if not isinstance(x, np.ndarray):
        x = np.asarray(x)
      # Create last dimension if it doesn't exist.
      # The `and` part of the check ensures we're not touching ragged arrays.
      if x.ndim == 1 and x[0].ndim == 0:
        x = np.expand_dims(x, -1)
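      # Flatten the per-image arrays into a single (num_rays, last_dim) array.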
      if flatten:
        x = np.concatenate([y.reshape(-1, y.shape[-1]) for y in x], axis=0)
      if shuffle:
        x = x[shuffled_inds]
      return x

    out_dict = {}
    for key, value in data_dict.items():
      out_dict[key] = jax.tree_map(_prepare_array, value)

    return tf.data.Dataset.from_tensor_slices(out_dict)
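A minimal usage sketch of the returned dataset; `datasource`, `train_ids`, and the batch size of 4096 are assumptions for illustration, not names from the snippet above:

import tensorflow as tf

# Hypothetical caller: `datasource` is an instance of the class that owns
# _create_preloaded_dataset, and `train_ids` is a list of item IDs.
dataset = datasource._create_preloaded_dataset(
    train_ids, flatten=True, shuffle=True)
# With flatten=True every element is one pre-shuffled ray, so plain batching
# yields random ray batches without a large tf.data shuffle buffer.
dataset = dataset.batch(4096).repeat().prefetch(tf.data.AUTOTUNE)
for batch in dataset.take(1):
  print(batch['origins'].shape)  # e.g. (4096, 3)

Because the shuffle happens once in NumPy before the dataset is built, the pipeline avoids the memory and latency cost of a ray-level Dataset.shuffle buffer.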
Example #2
  def create_cameras_dataset(self,
                             cameras: Union[Iterable[cam.Camera],
                                            Iterable[gpath.GPath]],
                             flatten=False,
                             shuffle=False):
    """Creates a ray dataset from cameras or paths to camera files."""
    if isinstance(cameras[0], (gpath.GPath, str)):
      # Paths were passed in, so load the actual cameras in parallel.
      cameras = utils.parallel_map(self.load_camera, cameras)
    data_dict = utils.tree_collate([camera_to_rays(c) for c in cameras])
    return dataset_from_dict(data_dict,
                             rng=self.rng,
                             flatten=flatten,
                             shuffle=shuffle)
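A sketch of how this might be called; the `datasource` object and the camera file paths are assumptions for illustration:

# Hypothetical usage: `datasource` and the path layout are placeholders.
camera_paths = [gpath.GPath(f'cameras/{i:06d}.json') for i in range(100)]
dataset = datasource.create_cameras_dataset(
    camera_paths, flatten=True, shuffle=True)

Passing paths defers the (parallelized) camera loading to this method, while passing cam.Camera instances skips that step entirely.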
Example #3
  def parallel_get_items(self, item_ids, scale_factor=1.0):
    """Loads data dictionaries for the given item IDs in parallel."""
    load_fn = functools.partial(self.get_item, scale_factor=scale_factor)
    data_list = utils.parallel_map(load_fn, item_ids)
    data_dict = utils.tree_collate(data_list)
    return data_dict
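`utils.tree_collate` is used by all three examples but not shown. A plausible stand-in, assuming it stacks the leaves of a list of identically-structured nested dicts, might look like this (an assumption about the helper, not its actual source):

import numpy as np

def tree_collate(list_of_trees):
  """Stacks the leaves of a list of nested dicts (assumed implementation)."""
  first = list_of_trees[0]
  if isinstance(first, dict):
    return {k: tree_collate([t[k] for t in list_of_trees]) for k in first}
  try:
    return np.stack(list_of_trees)
  except ValueError:
    # Ragged leaves (e.g. images with different sizes) become a 1-D object
    # array, matching the per-example indexing done in Example #1.
    out = np.empty(len(list_of_trees), dtype=object)
    for i, leaf in enumerate(list_of_trees):
      out[i] = np.asarray(leaf)
    return out

Under this assumption, Example #1 can index data_dict['origins'] per example even when images have different heights and widths.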