def parse_dataset(dataset: tf.data.TFRecordDataset):
    """Parse an entire TFRecord dataset.

    Arguments:
        dataset {tf.data.TFRecordDataset} -- dataset loaded from the target filesystem
    """
    return dataset.map(parse_image_function)
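# NOTE: parse_image_function is defined elsewhere in the repo. As a minimal
# sketch only, assuming a hypothetical record layout with a serialized image
# and an integer label (illustrative names, not the repo's actual spec):
def parse_image_function(record: tf.Tensor):
    hypothetical_features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    return tf.io.parse_single_example(record, hypothetical_features)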
def post_process(self, dataset: tf.data.TFRecordDataset, n_parallel_calls: int):
    dataset = super().post_process(dataset, n_parallel_calls)
    dataset = dataset.map(self.add_recovery_probabilities)
    # keep only the examples where the system is stuck
    dataset = dataset.filter(is_stuck)
    return dataset
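# NOTE: is_stuck is defined elsewhere. Predicates passed to Dataset.filter must
# return a scalar tf.bool; a plausible sketch, assuming add_recovery_probabilities
# writes a 'recovery_probability' feature (an illustrative guess, as is the
# cutoff value), might be:
def is_stuck(example):
    return tf.reduce_all(example['recovery_probability'] < 1e-3)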
def post_process(self, dataset: tf.data.TFRecordDataset, n_parallel_calls: int, **kwargs):
    dataset = super().post_process(dataset, n_parallel_calls)
    if self.use_gt_rope:
        print(Fore.GREEN + "Using ground-truth rope state")
        dataset = dataset.map(use_gt_rope)
    return dataset
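# NOTE: use_gt_rope is defined elsewhere in the repo. A minimal sketch, assuming
# the parsed examples carry both an estimated 'rope' state and a ground-truth
# 'gt_rope' feature (hypothetical keys), would simply swap one in for the other:
def use_gt_rope(example: Dict):
    example['rope'] = example['gt_rope']
    return example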
def post_process(self, dataset: tf.data.TFRecordDataset, n_parallel_calls: int):
    # capture the metadata in a local so the map function below closes over a
    # plain dict rather than over self
    scenario_metadata = self.scenario_metadata

    def _add_scenario_metadata(example: Dict):
        example.update(scenario_metadata)
        return example

    dataset = dataset.map(_add_scenario_metadata)
    return dataset
def post_process(self, dataset: tf.data.TFRecordDataset, n_parallel_calls: int):
    dataset = super().post_process(dataset, n_parallel_calls)

    def _add_time(example: Dict):
        # this function is called before batching occurs, so the first dimension should be time
        example['time'] = tf.cast(self.horizon, tf.int64)
        return example

    # dataset = dataset.map(_add_time)

    # capture the threshold in a local so the map function does not close over self
    threshold = self.threshold

    def _label(example: Dict):
        add_label(example, threshold)
        return example

    if not self.old_compat:
        dataset = dataset.map(_label)

    if self.use_gt_rope:
        dataset = dataset.map(use_gt_rope)

    return dataset
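# NOTE: add_label is defined elsewhere. A minimal sketch, assuming labeling means
# thresholding a per-example 'error' feature (a hypothetical key) into a binary
# 'is_close' label, might look like this; it mutates the example in place, which
# is why _label above returns the example itself:
def add_label(example: Dict, threshold: float):
    example['is_close'] = tf.cast(example['error'] < threshold, tf.float32)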
def parse(self, dataset: tf.data.TFRecordDataset, split: Split = None) -> tf.data.Dataset:
    # TODO: Consider split
    feature_specs = self.feature_specs  # type: Dict[str, FeatureSpec]
    features = {k: feature_spec.features for k, feature_spec in feature_specs.items()}
    flatten_features = flatten_nested_dict(features)

    def parse_fn(record: tf.Tensor):
        parsed = tf.io.parse_single_example(record, flatten_features)
        unflatten = {key: [] for key in features.keys()}
        # unflatten: regroup 'outer/inner' keys by their outer key
        for flatten_key, tensor in parsed.items():
            key, inner_key = flatten_key.split('/')
            unflatten[key].append((inner_key, tensor))
        return {k: feature_specs[k].build_tensor(dict(tups)) for k, tups in unflatten.items()}

    return dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
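# NOTE: flatten_nested_dict is defined elsewhere. Given the split('/') in
# parse_fn above, a sketch consistent with one level of nesting would join
# outer and inner keys with '/':
def flatten_nested_dict(nested: Dict):
    return {f'{outer}/{inner}': v for outer, d in nested.items() for inner, v in d.items()}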