Example #1
File: train.py Project: edhenry/bot
import tensorflow as tf


def parse_dataset(dataset: tf.data.TFRecordDataset):
    """Parse an entire TFRecord dataset.

    Arguments:
        dataset {tf.data.TFRecordDataset} -- dataset loaded from the target filesystem

    Returns:
        tf.data.Dataset -- dataset with every record mapped through parse_image_function
    """
    return dataset.map(parse_image_function)
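Note that parse_image_function is defined elsewhere in the edhenry/bot project; a minimal sketch of such a per-record parser, with assumed feature names, might look like this:

import tensorflow as tf

# Assumed feature layout; the real parse_image_function and its feature
# description live elsewhere in the project.
image_feature_description = {
    'image_raw': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def parse_image_function(example_proto):
    # Parse one serialized tf.train.Example into a dict of tensors.
    return tf.io.parse_single_example(example_proto, image_feature_description)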
Example #2
    def post_process(self, dataset: tf.data.TFRecordDataset,
                     n_parallel_calls: int):
        dataset = super().post_process(dataset, n_parallel_calls)

        # add_recovery_probabilities and is_stuck are project helpers:
        # annotate each example, then keep only the stuck ones.
        dataset = dataset.map(self.add_recovery_probabilities)
        dataset = dataset.filter(is_stuck)
        return dataset
Example #3
    def post_process(self, dataset: tf.data.TFRecordDataset,
                     n_parallel_calls: int, **kwargs):
        dataset = super().post_process(dataset, n_parallel_calls)

        if self.use_gt_rope:
            print(Fore.GREEN + "Using ground-truth rope state")
            # use_gt_rope is a mapping helper defined elsewhere in the project
            dataset = dataset.map(use_gt_rope)

        return dataset
Example #4
    def post_process(self, dataset: tf.data.TFRecordDataset,
                     n_parallel_calls: int):
        scenario_metadata = self.scenario_metadata

        def _add_scenario_metadata(example: Dict):
            # merge the scenario metadata into every example
            example.update(scenario_metadata)
            return example

        dataset = dataset.map(_add_scenario_metadata)
        return dataset
Example #5
    def post_process(self, dataset: tf.data.TFRecordDataset,
                     n_parallel_calls: int):
        dataset = super().post_process(dataset, n_parallel_calls)

        def _add_time(example: Dict):
            # this function is called before batching occurs, so the first dimension should be time
            example['time'] = tf.cast(self.horizon, tf.int64)
            return example

        # dataset = dataset.map(_add_time)

        threshold = self.threshold

        def _label(example: Dict):
            # add_label is a project helper; it labels the example using the threshold
            add_label(example, threshold)
            return example

        if not self.old_compat:
            dataset = dataset.map(_label)

        if self.use_gt_rope:
            dataset = dataset.map(use_gt_rope)

        return dataset
Example #6
    def parse(self, dataset: tf.data.TFRecordDataset, split: Split = None) -> tf.data.Dataset:
        # TODO: Consider split
        feature_specs = self.feature_specs  # type: Dict[str, FeatureSpec]
        features = {k: feature_spec.features for k, feature_spec in feature_specs.items()}
        flatten_features = flatten_nested_dict(features)

        def parse_fn(record: tf.Tensor):
            parsed = tf.io.parse_single_example(record, flatten_features)

            unflatten = {key: [] for key in features.keys()}

            # Unflatten: flattened keys have the form 'key/inner_key'
            for flatten_key, tensor in parsed.items():
                key, inner_key = flatten_key.split('/')

                unflatten[key].append((inner_key, tensor))

            return {k: feature_specs[k].build_tensor(dict(tups)) for k, tups in unflatten.items()}

        return dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
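parse_fn above splits flattened keys on '/', so flatten_nested_dict is assumed to join nested keys with '/'. A minimal sketch of such a helper (an assumption, not the project's actual implementation):

def flatten_nested_dict(nested):
    # Flatten {'outer': {'inner': value}} into {'outer/inner': value},
    # matching the key/inner_key split used in parse_fn above.
    flat = {}
    for outer_key, inner in nested.items():
        for inner_key, value in inner.items():
            flat[f'{outer_key}/{inner_key}'] = value
    return flat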