def _create_preloaded_dataset(self, item_ids, flatten=False, shuffle=False):
  """Creates a dataset completely preloaded in memory.

  This creates a tf.data.Dataset which is constructed by loading all data
  into memory and pre-shuffling it (if applicable). This is much faster than
  having tf.data.Dataset handle individual items.

  Args:
    item_ids: the item IDs to construct the dataset with.
    flatten: whether to flatten the image dimensions.
    shuffle: whether to shuffle the dataset.

  Returns:
    A tf.data.Dataset instance.
  """
  load_fn = functools.partial(self.get_item)
  data_list = utils.parallel_map(load_fn, item_ids)
  data_list = [_camera_to_rays_fn(item) for item in data_list]
  data_dict = utils.tree_collate(data_list)

  num_examples = data_dict['origins'].shape[0]
  heights = [x.shape[0] for x in data_dict['origins']]
  widths = [x.shape[1] for x in data_dict['origins']]

  # Broadcast appearance ID to match ray shapes.
  if 'metadata' in data_dict:
    for metadata_key, metadata in data_dict['metadata'].items():
      data_dict['metadata'][metadata_key] = np.asarray([
          np.full((heights[i], widths[i], 1), fill_value=x)
          for i, x in enumerate(metadata)
      ])

  num_rays = int(sum([x * y for x, y in zip(heights, widths)]))
  shuffled_inds = self.rng.permutation(num_rays)

  logging.info('*** Loaded dataset items: num_rays=%d, num_examples=%d',
               num_rays, num_examples)

  def _prepare_array(x):
    if not isinstance(x, np.ndarray):
      x = np.asarray(x)
    # Create the last dimension if it doesn't exist.
    # The `and` part of the check ensures we're not touching ragged arrays.
    if x.ndim == 1 and x[0].ndim == 0:
      x = np.expand_dims(x, -1)
    if flatten:
      # Flatten each example's image dimensions into a single ray axis.
      x = np.concatenate([y.reshape(-1, y.shape[-1]) for y in x], axis=0)
    if shuffle:
      x = x[shuffled_inds]
    return x

  out_dict = {}
  for key, value in data_dict.items():
    out_dict[key] = jax.tree_map(_prepare_array, value)

  return tf.data.Dataset.from_tensor_slices(out_dict)
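
# A minimal usage sketch (hedged): `train_ids` and `batch_size` are
# hypothetical caller-side names, and the 'directions' key is assumed to be
# produced by `_camera_to_rays_fn` alongside the 'origins' key used above.
# Since the rays are preloaded and pre-shuffled, the caller can batch the
# resulting tf.data.Dataset directly:
#
#   dataset = self._create_preloaded_dataset(
#       train_ids, flatten=True, shuffle=True)
#   dataset = dataset.batch(batch_size).repeat().prefetch(tf.data.AUTOTUNE)
#   for batch in dataset.as_numpy_iterator():
#     origins = batch['origins']        # Shape (batch_size, 3), assumed.
#     directions = batch['directions']  # Shape (batch_size, 3), assumed.
#     ...
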
def create_cameras_dataset(self,
                           cameras: Union[Iterable[cam.Camera],
                                          Iterable[gpath.GPath]],
                           flatten=False,
                           shuffle=False):
  """Creates a ray dataset from a sequence of cameras or camera paths."""
  if isinstance(cameras[0], (gpath.GPath, str)):
    cameras = utils.parallel_map(self.load_camera, cameras)
  data_dict = utils.tree_collate([camera_to_rays(c) for c in cameras])
  return dataset_from_dict(
      data_dict, rng=self.rng, flatten=flatten, shuffle=shuffle)
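
# `dataset_from_dict` is defined elsewhere in this module; the commented
# sketch below shows the behavior it is assumed to have, mirroring the
# flatten/shuffle logic of `_create_preloaded_dataset` above. It is an
# illustrative stand-in, not the actual implementation.
#
#   def dataset_from_dict(data_dict, rng, flatten=False, shuffle=False):
#     """Builds a tf.data.Dataset from collated ray arrays (sketch)."""
#     heights = [x.shape[0] for x in data_dict['origins']]
#     widths = [x.shape[1] for x in data_dict['origins']]
#     num_rays = int(sum(h * w for h, w in zip(heights, widths)))
#     shuffled_inds = rng.permutation(num_rays)
#
#     def _prepare_array(x):
#       x = np.asarray(x)
#       if flatten:
#         x = np.concatenate(
#             [y.reshape(-1, y.shape[-1]) for y in x], axis=0)
#       if shuffle:
#         x = x[shuffled_inds]
#       return x
#
#     out_dict = jax.tree_map(_prepare_array, data_dict)
#     return tf.data.Dataset.from_tensor_slices(out_dict)
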
def parallel_get_items(self, item_ids, scale_factor=1.0):
  """Load data dictionaries indexed by indices in parallel."""
  load_fn = functools.partial(self.get_item, scale_factor=scale_factor)
  data_list = utils.parallel_map(load_fn, item_ids)
  data_dict = utils.tree_collate(data_list)
  return data_dict
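
# Example usage (hedged): the item IDs and the 'rgb' key below are
# hypothetical; the actual keys depend on what `get_item` returns for the
# concrete dataset.
#
#   data = self.parallel_get_items(['000001', '000002'], scale_factor=0.5)
#   rgb = data['rgb']  # Per-item arrays collated along a leading axis.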