Example #1
    def _prepare_dataset(
        self,
        dataset: tf.data.Dataset,
        shuffle: bool = False,
        augment: bool = False
    ) -> tf.data.Dataset:
        preprocessing_model = self._build_preprocessing()
        dataset = dataset.map(
            map_func=lambda x, y: (preprocessing_model(x, training=False), y),
            num_parallel_calls=tf.data.experimental.AUTOTUNE
        )

        if shuffle:
            dataset = dataset.shuffle(buffer_size=1_000)

        dataset = dataset.batch(batch_size=self.batch_size)

        if augment:
            data_augmentation_model = self._build_data_augmentation()
            dataset = dataset.map(
                # training=True so the random augmentation layers are actually applied
                map_func=lambda x, y: (data_augmentation_model(x, training=True), y),
                num_parallel_calls=tf.data.experimental.AUTOTUNE
            )

        return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
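The _build_preprocessing and _build_data_augmentation helpers are not shown in this snippet. As a rough, hypothetical sketch of what such builder methods commonly return (written here as plain functions; names and layers are assumed, and the tf.keras.layers preprocessing layers require TF 2.6+):

import tensorflow as tf

def _build_preprocessing():
    # Deterministic preprocessing: rescale pixel values from [0, 255] to [0, 1].
    return tf.keras.Sequential([tf.keras.layers.Rescaling(1.0 / 255)])

def _build_data_augmentation():
    # Random augmentation layers; they only take effect when called with training=True.
    return tf.keras.Sequential([
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.1),
    ])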
Example #2
    def process(self, dataset: tf.data.Dataset, batch_size: int):
        dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)

        if self.cache:
            dataset = dataset.cache()

        if self.shuffle:
            dataset = dataset.shuffle(self.buffer_size,
                                      reshuffle_each_iteration=True)

        # PADDED BATCH the dataset
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=(
                tf.TensorShape([]),
                tf.TensorShape(self.speech_featurizer.shape),
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([]),
            ),
            padding_values=("", 0., 0, self.text_featurizer.blank, 0,
                            self.text_featurizer.blank, 0),
            drop_remainder=self.drop_remainder)

        # PREFETCH to overlap input preprocessing with model execution
        dataset = dataset.prefetch(AUTOTUNE)
        self.total_steps = get_num_batches(self.total_steps, batch_size)
        return dataset
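For context, here is a minimal standalone sketch (toy data, not from the original project) of what padded_batch does with explicit padded_shapes and padding_values:

import tensorflow as tf

def gen():
    yield [1, 2, 3]
    yield [4, 5]
    yield [6]

toy = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

batched = toy.padded_batch(
    batch_size=3,
    padded_shapes=tf.TensorShape([None]),  # pad the variable-length dimension
    padding_values=0)                      # fill value used for the padding

for batch in batched:
    print(batch.numpy())  # [[1 2 3] [4 5 0] [6 0 0]]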
Example #3
def create_dataset(dataset: tf.data.Dataset, num_classes: int,
                   is_training: bool) -> tf.data.Dataset:
  """Produces a full, augmented dataset from the inptu dataset."""
  _, _, resolution, _ = efficientnet_builder.efficientnet_params(
      FLAGS.model_name)

  def process_data(image, label):
    image = preprocessing.preprocess_image(
        image,
        is_training=is_training,
        use_bfloat16=FLAGS.strategy == 'tpus',
        image_size=resolution,
        augment_name=FLAGS.augment_name,
        randaug_num_layers=FLAGS.randaug_num_layers,
        randaug_magnitude=FLAGS.randaug_magnitude,
        resize_method=None)

    label = tf.one_hot(label, num_classes)
    return image, label

  dataset = dataset.map(
      process_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  return dataset
Example #4
    def get_bucket_iter(self,
                        dataset: tf.data.Dataset,
                        batch_size=32,
                        train=True) -> tf.data.Dataset:
        padded_shapes = self._padded_shapes()
        padding_values = self._padding_values()
        if train:
            bucket_boundaries = self._bucket_boundaries(batch_size)
            bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
            dataset = dataset.apply(
                tf.data.experimental.bucket_by_sequence_length(
                    self.element_length_func,
                    bucket_boundaries,
                    bucket_batch_sizes,
                    padded_shapes=padded_shapes,
                    padding_values=padding_values))
            dataset = dataset.shuffle(100)
        else:
            dataset = dataset.padded_batch(batch_size,
                                           padded_shapes=padded_shapes)

        dataset = dataset.map(self._collate_fn,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset
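A self-contained toy sketch (illustrative values only) of what bucket_by_sequence_length does: elements are routed into buckets by length and each bucket is padded and batched separately, which keeps padding overhead small:

import tensorflow as tf

def gen():
    for length in [3, 5, 12, 4, 11, 2]:
        yield [1] * length

toy = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

bucketed = toy.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[6],       # sequences shorter than 6 vs. the rest
        bucket_batch_sizes=[2, 2]))  # one batch size per bucket (len(boundaries) + 1)

for batch in bucketed:
    print(batch.shape)  # e.g. (2, 5) from the short bucket, (2, 12) from the long one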
Example #5
def memoize(dataset: tf.data.Dataset) -> tf.data.Dataset:
    data = []
    with tf.Graph().as_default(), tf.Session(
            config=utils.get_config()) as session:
        dataset = dataset.prefetch(16)
        it = dataset.make_one_shot_iterator().get_next()
        try:
            while 1:
                data.append(session.run(it))
        except tf.errors.OutOfRangeError:
            pass
    images = np.stack([x['image'] for x in data])
    labels = np.stack([x['label'] for x in data])

    def tf_get(index):
        def get(index):
            return images[index], labels[index]

        image, label = tf.py_func(get, [index], [tf.float32, tf.int64])
        return dict(image=image, label=label)

    dataset = tf.data.Dataset.range(len(data)).repeat()
    dataset = dataset.shuffle(
        len(data) if len(data) < FLAGS.shuffle else FLAGS.shuffle)
    return dataset.map(tf_get)
Example #6
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # repeat before map so stateless map is different across epochs
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # if map_func has random elements this won't be deterministic
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
Example #7
File: Task.py  Project: nitrogenase/TAPE
    def prepare_dataset(self,
                        dataset: tf.data.Dataset,
                        buckets: List[int],
                        batch_sizes: List[int],
                        shuffle: bool = False) -> tf.data.Dataset:
        dataset = dataset.map(self._deserialization_func,
                              num_parallel_calls=128)

        buckets_array = np.array(buckets)
        batch_sizes_array = np.array(batch_sizes)

        if np.any(batch_sizes_array == 0) and shuffle:
            iszero = np.where(batch_sizes_array == 0)[0][0]
            filterlen = buckets_array[iszero - 1]
            print("Filtering sequences of length {}".format(filterlen))
            dataset = dataset.filter(
                lambda example: example['protein_length'] < filterlen)
        else:
            batch_sizes_array[batch_sizes_array <= 0] = 1

        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            operator.itemgetter('protein_length'), buckets_array,
            batch_sizes_array)
        dataset = dataset.apply(batch_fun)
        return dataset
Example #8
def create_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
    dataset = dataset.map(normalize_img,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
Example #9
def prepare_Dataset(dataset: tf.data.Dataset,
                    shuffle: bool = False,
                    augment: bool = False) -> tf.data.Dataset:
    """Prepare the dataset object with preprocessing and data augmentation.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The dataset object
    shuffle : bool, optional
        Whether to shuffle the dataset, by default False
    augment : bool, optional
        Whether to augment the train dataset, by default False

    Returns
    -------
    tf.data.Dataset
        The prepared dataset
    """
    preprocessing_model = build_preprocessing()
    dataset = dataset.map(map_func=lambda x, y: (preprocessing_model(x), y),
                          num_parallel_calls=AUTOTUNE)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=1_000)

    dataset = dataset.batch(batch_size=BATCH_SIZE)

    if augment:
        data_augmentation_model = build_data_augmentation()
        dataset = dataset.map(map_func=lambda x, y:
                              (data_augmentation_model(x), y),
                              num_parallel_calls=AUTOTUNE)

    return dataset.prefetch(buffer_size=AUTOTUNE)
Example #10
def preprocess_dataset(dataset: tf.data.Dataset, batch_size: int, n_step_returns: int, discount: float):
  d_len = sum([1 for _ in dataset])
  dataset = dataset.map(
      lambda *x: n_step_transition_from_episode(
          *x, n_step=n_step_returns, additional_discount=discount))
  dataset = dataset.repeat().shuffle(d_len).batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
Example #11
def preprocessing(dsData: tf.data.Dataset, batch_size = 32, window_size = 20):
    dsData = dsData.window(window_size+1, shift=1, drop_remainder=True)
    dsData = dsData.flat_map(lambda win: win.batch(window_size+1))
    dsData = dsData.map(lambda x: (x[:-1], x[-1]))
    dsData = dsData.shuffle(1000)
    dsData = dsData.batch(batch_size)
    dsData = dsData.prefetch(1)
    return dsData
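A worked toy version of the same window/flat_map/map chain (illustrative only), showing how a flat series becomes (history, next value) pairs:

import tensorflow as tf

window_size = 3
ds = tf.data.Dataset.range(6)                                  # 0, 1, 2, 3, 4, 5
ds = ds.window(window_size + 1, shift=1, drop_remainder=True)  # nested windows of 4 values
ds = ds.flat_map(lambda win: win.batch(window_size + 1))       # each window -> one tensor
ds = ds.map(lambda x: (x[:-1], x[-1]))                         # split into (inputs, target)

for x, y in ds:
    print(x.numpy(), y.numpy())
# [0 1 2] 3
# [1 2 3] 4
# [2 3 4] 5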
Example #12
  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if self._num_gpus > 1:
      dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
      # Shuffle the input files.
      dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
      dataset = dataset.repeat()

    # Read the data from disk in parallel
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
      dataset = dataset.cache()

    if self.is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)
      dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
      # The batch size of the dataset will be multiplied by the number of
      # replicas automatically when strategy.distribute_datasets_from_function
      # is called, so we use local batch size here.
      dataset = dataset.batch(self.local_batch_size,
                              drop_remainder=self.is_training)
    else:
      dataset = dataset.batch(self.global_batch_size,
                              drop_remainder=self.is_training)

    # Apply Mixup
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(self.mixup, self.local_batch_size, mixup_alpha),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
Example #13
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
    """Function applies batch() and prefetch() functions
    to the dataset to optimize data processing.
    :param ds: TensorFlow Dataset object
    :return: Batched TensorFlow Dataset object
    """
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(buffer_size=AUTOTUNE).cache()
    return ds
Example #14
    def _prepare_dataset(self,
                         dataset: tf.data.Dataset,
                         shuffle: bool = False,
                         augment: bool = False) -> tf.data.Dataset:
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1_000)

        dataset = dataset.batch(batch_size=self.batch_size)

        return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
Example #15
    def prepare_dataset(self,
                        dataset: tf.data.Dataset,
                        buckets: List[int],
                        batch_sizes: List[int],
                        shuffle: bool = False) -> tf.data.Dataset:
        dataset = dataset.map(self._deserialization_func, 128)
        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            lambda example: tf.maximum(example['first']['protein_length'],
                                       example['second']['protein_length']),
            buckets, batch_sizes)
        dataset = dataset.apply(batch_fun)
        return dataset
Example #16
    def _hook_dataset_post_precessing(self, ds: tf.data.Dataset,
                                      batch_size: int):
        ds = ds.padded_batch(batch_size=batch_size, padded_shapes=[None])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

        def add_padding_mask(source):
            enc_padding_mask = create_padding_mask_fm(source)
            return (source, enc_padding_mask), source

        ds = ds.map(map_func=add_padding_mask)  # Add pad mask

        # For masked language model task, all datasets are masked
        return self._apply_mask_for_mlm(ds=ds, vocab_size=self._vocab_size)
Example #17
        def mixup(dataset: tf.data.Dataset) -> tf.data.Dataset:
            def mixup_map(data, shuffled):
                dist = tfp.distributions.Beta([alpha], [alpha])
                beta = dist.sample([1])[0][0]

                ret = {}
                ret["image"] = (data["image"]) * beta + (shuffled["image"] * (1-beta))
                ret["label"] = (data["label"]) * beta + (shuffled["label"] * (1-beta))
                return ret

            shuffle_dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE).shuffle(mixup_size)
            zipped = tf.data.Dataset.zip((dataset, shuffle_dataset))
            return zipped.map(mixup_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)
Example #18
    def fit(self,
            model: Model,
            train_dataset: tf.data.Dataset,
            validation_dataset: Optional[tf.data.Dataset] = None,
            test_dataset: Optional[tf.data.Dataset] = None,
            epochs: int = 10,
            batch_size: int = 1,
            **kwargs) -> History:
        """Train the model with given setting, perform evaluation if
            `test_dataset` provided.

        Args:
            model: the model to be fit.
            train_dataset: a training dataset passed to `model.fit()`.
            validation_dataset:
                validation dataset passed to `model.fit()`. Defaults to None.
            test_dataset:
                test dataset passed to `model.evaluate()`. Defaults to None.
            epochs: number of epochs to train. Defaults to 10.
            batch_size: number of samples per batch. Defaults to 1.

        Returns:
            History: history of training.
        """

        out_shape = get_output_shape(model, train_dataset)

        train_dataset = train_dataset.map(
            layers.crop_labels_to_shape(out_shape)).batch(batch_size)
        train_dataset = train_dataset.prefetch(
            buffer_size=tf.data.experimental.AUTOTUNE)
        if validation_dataset:
            validation_dataset = validation_dataset.map(
                layers.crop_labels_to_shape(out_shape)).batch(batch_size)

        callbacks = self.build_callbacks()

        history = model.fit(train_dataset,
                            validation_data=validation_dataset,
                            epochs=epochs,
                            callbacks=callbacks,
                            **kwargs)

        if test_dataset:
            test_dataset = test_dataset\
                .map(layers.crop_labels_to_shape(out_shape))\
                .batch(batch_size)
            model.evaluate(test_dataset)

        return history
Example #19
    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with prefetching to maintain a buffer during iteration.

        Args:
            ds_input: Any dataset.

        Returns:
            A `tf.data.Dataset` with identical elements. Processing that occurs with the
            elements that are produced can be done in parallel (e.g., training on the
            GPU) while new elements are generated from the pipeline.
        """
        if self.prefetch:
            return ds_input.prefetch(buffer_size=self.buffer_size)
        else:
            return ds_input
Example #20
    def _prepare_dataset(self,
                         dataset: tf.data.Dataset,
                         shuffle: bool = False) -> tf.data.Dataset:
        dataset = dataset.map(
            map_func=lambda x, y: (
                tf.reshape(
                    self.normalization_layer(
                        tf.reshape(x, shape=(1, self.num_features)),
                        training=False),
                    shape=(self.num_features,)),
                y),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=1_000)

        dataset = dataset.batch(batch_size=self.batch_size)

        return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
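The reshape-through-a-normalization-layer pattern above can be reproduced standalone. A hedged sketch with made-up data (assumes TF 2.6+, where tf.keras.layers.Normalization is available):

import numpy as np
import tensorflow as tf

num_features = 4
features = np.random.rand(8, num_features).astype("float32")
labels = np.arange(8)

norm = tf.keras.layers.Normalization()
norm.adapt(features)  # learn per-feature mean and variance from the training data

ds = tf.data.Dataset.from_tensor_slices((features, labels))
# The layer expects a leading batch dimension, so each (num_features,) vector is
# reshaped to (1, num_features), normalized, then reshaped back.
ds = ds.map(
    lambda x, y: (tf.reshape(norm(tf.reshape(x, (1, num_features))),
                             (num_features,)), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)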
Example #21
    def get_data_iter(self,
                      dataset: tf.data.Dataset,
                      batch_size=32,
                      train=True) -> tf.data.Dataset:
        padded_shapes = self._padded_shapes()
        padding_values = self._padding_values()
        if train:
            dataset = dataset.shuffle(batch_size * 100)
        dataset = dataset.padded_batch(batch_size,
                                       padded_shapes=padded_shapes,
                                       padding_values=padding_values)
        dataset = dataset.map(self._collate_fn,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset
Example #22
    def prefetch(self, data: tf.data.Dataset) -> tf.data.Dataset:
        """
        Prefetch the data to memory

        Parameters
        ----------
        data
            The input dataset

        Returns
        -------
        data
            data after prefetching
        """
        data = data.prefetch(self.prefetch_buffer_size)
        return data
Example #23
def prepare_dataset(
    dataset: tf.data.Dataset,
    model_image_size: Tuple[int, int],
    augmentation_fn: Optional[ImageDataMapFn] = None,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int] = None,
    num_parallel_calls: Optional[int] = None,
    prefetch_buffer_size: Optional[int] = None,
    prefetch_to_device: Optional[str] = None,
) -> tf.data.Dataset:

    # apply data augmentation:
    if augmentation_fn is not None:
        dataset = dataset.map(
            map_image_data(augmentation_fn),
            num_parallel_calls=num_parallel_calls,
        )

    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(
        map_image_data(prepare_for_batching(model_image_size)),
        num_parallel_calls=num_parallel_calls,
    )

    # batching and padding
    if batch_size is not None:
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=get_padding_shapes(
                dataset, spatial_image_shape=model_image_size),
            drop_remainder=True,
        )

    # optionally prefetch the dataset onto a specific device
    if prefetch_to_device is not None:
        prefetch_fn = tf.data.experimental.prefetch_to_device(
            device=prefetch_to_device, buffer_size=prefetch_buffer_size)
        dataset = dataset.apply(prefetch_fn)
    else:
        if prefetch_buffer_size is not None:
            dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

    return dataset
Example #24
    def get_tfds_data_loader(data: tf.data.Dataset, data_subset_mode='train',
                             batch_size=32, num_samples=100, num_classes=19,
                             infinite=True, augment=True, seed=2836):


        def encode_example(x, y):
            x = tf.image.convert_image_dtype(x, tf.float32) * 255.0
            y = _encode_label(y, num_classes=num_classes)
            return x, y

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.shuffle(buffer_size=num_samples) \
                   .cache() \
                   .map(encode_example, num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.map(preprocess_input, num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        if data_subset_mode == 'train':
            data = data.shuffle(buffer_size=100, seed=seed)
            augmentor = TRAIN_image_augmentor
        elif data_subset_mode == 'val':
            augmentor = VAL_image_augmentor
        elif data_subset_mode == 'test':
            augmentor = TEST_image_augmentor

        if augment:
            data = augmentor.apply_augmentations(data)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.batch(batch_size, drop_remainder=True)
        if infinite:
            data = data.repeat()

        return data.prefetch(AUTOTUNE)
Example #25
def prepare_for_training(data_set: tf.data.Dataset,
                         batch_size,
                         cache_path=None,
                         shuffle_buffer_size=1000):
    if cache_path:  # only cache to disk when a cache path is provided
        cache_filename = 'dataset_train.tfcache'
        data_set = data_set.cache(''.join([cache_path, '/', cache_filename]))

    data_set = data_set.shuffle(buffer_size=shuffle_buffer_size)
    # repeat forever
    data_set = data_set.repeat()
    data_set = data_set.batch(batch_size=batch_size)
    # `prefetch` lets the dataset fetch batches in the background
    # while the model is training.
    data_set = data_set.prefetch(buffer_size=AUTOTUNE)

    return data_set
Example #26
def prefetch_to_available_gpu_device(dataset: tf.data.Dataset,
                                     buffer_size: int = None,
                                     use_workaround: bool = False):
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        assert len(gpus) == 1, \
            'Expected to find exactly 1 GPU, but found: {}'.format(gpus)
        if use_workaround:
            dataset = dataset.apply(
                tf.data.experimental.copy_to_device("/GPU:0"))
            with tf.device("/GPU:0"):
                return dataset.prefetch(buffer_size)
        else:
            return dataset.apply(
                tf.data.experimental.prefetch_to_device('/GPU:0', buffer_size))
    else:
        return dataset
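A hypothetical usage sketch for the helper above; on a machine without a GPU it simply returns the dataset unchanged:

import tensorflow as tf

ds = tf.data.Dataset.range(1024).batch(32)
ds = prefetch_to_available_gpu_device(ds, buffer_size=2)

batch = next(iter(ds))
print(batch.device)  # ends in GPU:0 when exactly one GPU is visible, CPU:0 otherwise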
Example #27
def _pipeline_ds(dataset: tf.data.Dataset,
                 batch_size: int,
                 is_training=False):
  """Preprocess and batch dataset."""
  if is_training:
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.repeat()
    dataset = dataset.map(_preprocess_train,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
  else:
    dataset = dataset.map(_preprocess_eval,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size,
                          drop_remainder=is_training)

  # Prefetch overlaps in-feed with training
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
Example #28
    def prepare_ds(dataset: tf.data.Dataset,
                   config: HyperparameterDict) -> tf.data.Dataset:
        # Cast to float
        dataset = dataset.map(lambda x: tf.cast(x, tf.float32),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.map(lambda x: config['rescaling'](x),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.map(config['resizing'],
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if config['cache_data']:
            # The dataset fits in memory, so cache before shuffling for better performance.
            dataset = dataset.cache()
        # For true randomness, set the shuffle buffer to the full dataset size.
        dataset = dataset.shuffle(1000)
        dataset = dataset.batch(config['batch_size'])
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset
Example #29
    def _preprocess_dataset(
        self,
        dataset_processor,
        ds: tf.data.Dataset,
    ) -> tf.data.Dataset:
        ds = dataset_processor.pre_process(ds, self.PARALLEL_CALLS)

        ds = ds.batch(self._batch_size)
        #
        # if self._repeat:
        #     ds = ds.repeat(self._repeat)
        # else:
        #     ds = ds.repeat()

        if self._use_prefetch:
            ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

        return ds
Example #30
def _wrap_dataset(
    mode: PipelineMode,
    dataset: tf.data.Dataset,
    pipeline_params: DataPipelineParams,
    data: DataBase,
    yields_batches: bool,
) -> tf.data.Dataset:
    """
    Shuffle, pad, batch, and prefetch a tf.data.Dataset
    """
    if pipeline_params.shuffle_buffer_size > 1:
        dataset = dataset.shuffle(pipeline_params.shuffle_buffer_size)
    if not yields_batches:
        dataset = _wrap_padded_batch(mode, dataset, pipeline_params, data)
    if pipeline_params.prefetch > 0:
        dataset = dataset.prefetch(pipeline_params.prefetch)
    dataset = dataset.take(compute_limit(pipeline_params.limit, pipeline_params.batch_size))
    return dataset