def create_cameras_dataset(
    self,
    cameras: Union[Iterable[tfcam.TFCamera], Iterable[gpath.GPath]],
    flatten=False,
    shuffle=False):
  """Creates a tf.data.Dataset from a list of cameras."""
  if isinstance(cameras[0], gpath.GPath) or isinstance(cameras[0], str):
    cameras = utils.parallel_map(self.load_camera, cameras)

  def _generator():
    for camera in cameras:
      yield {'camera_params': camera.get_parameters()}

  dataset = tf.data.Dataset.from_generator(
      _generator,
      output_signature={'camera_params': _TF_CAMERA_PARAMS_SIGNATURE})
  dataset = dataset.map(
      functools.partial(_camera_to_rays_fn, use_tf_camera=True),
      _TF_AUTOTUNE)
  if flatten:
    # Unbatch images to rows.
    dataset = dataset.unbatch()
    if shuffle:
      dataset = dataset.shuffle(20000)
    # Unbatch rows to rays.
    dataset = dataset.unbatch()
    if shuffle:
      dataset = dataset.shuffle(20000)

  return dataset
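
# Illustrative usage sketch (not part of the original module): assumes a
# hypothetical `source` instance of the class that defines
# `create_cameras_dataset` above, plus a list of camera file paths. With
# `flatten=True` the dataset yields one dict per ray, so batching produces
# flat ray batches.
def _example_stream_rays_from_camera_files(source, camera_paths):
  """Sketch: stream shuffled ray batches generated from camera files."""
  ray_dataset = source.create_cameras_dataset(
      camera_paths, flatten=True, shuffle=True)
  for batch in ray_dataset.batch(4096).take(1):
    # The exact keys are determined by `_camera_to_rays_fn`; 'origins' is one
    # of them (see `_create_preloaded_dataset` below), typically [4096, 3].
    print(batch['origins'].shape)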
def _create_preloaded_dataset(self, item_ids, flatten=False, shuffle=False):
  """Creates a dataset completely preloaded in memory.

  This creates a tf.data.Dataset which is constructed by loading all data
  into memory and pre-shuffling (if applicable). This is much faster than
  having tf.data.Dataset handle individual items.

  Args:
    item_ids: the item IDs to construct the dataset with.
    flatten: whether to flatten the image dimensions.
    shuffle: whether to shuffle the dataset.

  Returns:
    A tf.data.Dataset instance.
  """
  load_fn = functools.partial(self.get_item)
  data_list = utils.parallel_map(load_fn, item_ids)
  data_list = [_camera_to_rays_fn(item) for item in data_list]
  data_dict = utils.tree_collate(data_list)

  num_examples = data_dict['origins'].shape[0]
  heights = [x.shape[0] for x in data_dict['origins']]
  widths = [x.shape[1] for x in data_dict['origins']]

  # Broadcast appearance ID to match ray shapes.
  if 'metadata' in data_dict:
    for metadata_key, metadata in data_dict['metadata'].items():
      data_dict['metadata'][metadata_key] = np.asarray([
          np.full((heights[i], widths[i], 1), fill_value=x)
          for i, x in enumerate(metadata)
      ])

  num_rays = int(sum([x * y for x, y in zip(heights, widths)]))
  shuffled_inds = self.rng.permutation(num_rays)

  logging.info('*** Loaded dataset items: num_rays=%d, num_examples=%d',
               num_rays, num_examples)

  def _prepare_array(x):
    if not isinstance(x, np.ndarray):
      x = np.asarray(x)
    # Create last dimension if it doesn't exist.
    # The `and` part of the check ensures we're not touching ragged arrays.
    if x.ndim == 1 and x[0].ndim == 0:
      x = np.expand_dims(x, -1)
    if flatten:
      x = np.concatenate([y.reshape(-1, y.shape[-1]) for y in x], axis=0)
    if shuffle:
      x = x[shuffled_inds]
    return x

  out_dict = {}
  for key, value in data_dict.items():
    out_dict[key] = jax.tree_map(_prepare_array, value)

  return tf.data.Dataset.from_tensor_slices(out_dict)
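
# Illustrative usage sketch (assumption: `source` exposes the internal
# `_create_preloaded_dataset` and `item_ids` lists the items to load). Since
# every ray is loaded and pre-shuffled in memory, the returned dataset only
# needs cheap batching and prefetching.
def _example_preloaded_ray_batches(source, item_ids, batch_size=1024):
  """Sketch: build an in-memory ray dataset and draw repeating batches."""
  dataset = source._create_preloaded_dataset(  # pylint: disable=protected-access
      item_ids, flatten=True, shuffle=True)
  return dataset.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)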
def load_test_cameras(self, count=None):
  """Loads the cameras of the test camera trajectory, optionally subsampled."""
  camera_dir = self.data_dir / "camera-paths" / self.test_camera_trajectory
  if not camera_dir.exists():
    logging.warning("test camera path does not exist: %s", str(camera_dir))
    return []
  camera_paths = sorted(camera_dir.glob(f"*{self.camera_ext}"))
  if count is not None:
    stride = max(1, len(camera_paths) // count)
    camera_paths = camera_paths[::stride]
  cameras = utils.parallel_map(self.load_camera, camera_paths)
  return cameras
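
# Illustrative usage sketch (assumes `source` provides `load_test_cameras`):
# requesting `count` cameras subsamples the sorted camera files with an
# integer stride, so the number returned is approximately, not exactly,
# `count`.
def _example_load_subsampled_test_cameras(source, count=20):
  """Sketch: load roughly `count` evenly spaced test-trajectory cameras."""
  cameras = source.load_test_cameras(count=count)
  if not cameras:
    logging.warning('No test camera path found; skipping test renders.')
  return cameras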
def create_cameras_dataset(self,
                           cameras: Union[Iterable[cam.Camera],
                                          Iterable[gpath.GPath]],
                           flatten=False,
                           shuffle=False):
  """Creates a dataset of rays from a list of cameras or camera paths."""
  if isinstance(cameras[0], gpath.GPath) or isinstance(cameras[0], str):
    cameras = utils.parallel_map(self.load_camera, cameras)

  data_dict = utils.tree_collate([camera_to_rays(c) for c in cameras])
  return dataset_from_dict(
      data_dict, rng=self.rng, flatten=flatten, shuffle=shuffle)
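
# Illustrative usage sketch (assumes `source` defines the overload above):
# already-loaded `cam.Camera` objects can be passed directly, which skips the
# path-loading branch and goes straight to ray generation and collation.
def _example_rays_from_loaded_cameras(source, cameras):
  """Sketch: build a ray dataset from in-memory Camera objects."""
  return source.create_cameras_dataset(cameras, flatten=True, shuffle=False)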
def parallel_get_items(self, item_ids, scale_factor=1.0):
  """Loads the data dictionaries for the given item IDs in parallel."""
  load_fn = functools.partial(self.get_item, scale_factor=scale_factor)
  data_list = utils.parallel_map(load_fn, item_ids)
  data_dict = utils.tree_collate(data_list)
  return data_dict
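
# Illustrative usage sketch (assumes `source` defines `parallel_get_items`):
# loading a few items at half resolution, e.g. for quick validation renders.
# Because the results go through `utils.tree_collate`, each leaf of the
# returned dict groups the per-item arrays along a leading axis.
def _example_load_validation_items(source, item_ids):
  """Sketch: load a handful of items at half resolution for evaluation."""
  return source.parallel_get_items(item_ids[:4], scale_factor=0.5)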