Example #1
    def _prepare_dataset(
        self,
        dataset: tf.data.Dataset,
        shuffle: bool = False,
        augment: bool = False
    ) -> tf.data.Dataset:
        preprocessing_model = self._build_preprocessing()
        dataset = dataset.map(
            map_func=lambda x, y: (preprocessing_model(x, training=False), y),
            num_parallel_calls=tf.data.experimental.AUTOTUNE
        )

        if shuffle:
            dataset = dataset.shuffle(buffer_size=1_000)

        dataset = dataset.batch(batch_size=self.batch_size)

        if augment:
            data_augmentation_model = self._build_data_augmentation()
            # Random augmentation layers only transform inputs when training=True.
            dataset = dataset.map(
                map_func=lambda x, y: (data_augmentation_model(x, training=True), y),
                num_parallel_calls=tf.data.experimental.AUTOTUNE
            )

        return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
Example #2
    def window_dataset_for_zipped_example_and_label_dataset(
            dataset: tf.data.Dataset, batch_size: int,
            window_shift: int) -> tf.data.Dataset:
        """
        Takes a zipped example and label dataset, and converts it to batches, where each batch uses overlapping
        examples based on a sliding window.

        :param dataset: The zipped example and label dataset.
        :param batch_size: The size of the batches to produce.
        :param window_shift: The shift of the moving window between batches.
        :return: The window dataset.
        """
        examples_dataset = dataset.map(lambda element, _: element)
        labels_dataset = dataset.map(lambda _, element: element)
        examples_window_dataset = examples_dataset.window(batch_size,
                                                          shift=window_shift)
        examples_unbatched_window_dataset = examples_window_dataset.flat_map(
            lambda element: element)
        labels_window_dataset = labels_dataset.window(batch_size,
                                                      shift=window_shift)
        labels_unbatched_window_dataset = labels_window_dataset.flat_map(
            lambda element: element)
        unbatched_window_dataset = tf.data.Dataset.zip(
            (examples_unbatched_window_dataset,
             labels_unbatched_window_dataset))
        return unbatched_window_dataset.batch(batch_size)
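A minimal usage sketch, assuming the helper is callable at module level; the toy zipped dataset below is purely illustrative:

import tensorflow as tf

# Ten scalar examples zipped with matching labels.
examples = tf.data.Dataset.range(10)
labels = tf.data.Dataset.range(10).map(lambda x: x * 10)
zipped = tf.data.Dataset.zip((examples, labels))

# Batches of 4 that advance by 2 elements, so consecutive batches overlap.
windowed = window_dataset_for_zipped_example_and_label_dataset(
    zipped, batch_size=4, window_shift=2)
for x, y in windowed.take(2):
    print(x.numpy(), y.numpy())
# Expected: [0 1 2 3] [ 0 10 20 30], then [2 3 4 5] [20 30 40 50]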
Example #3
    def pre_process(self, dataset: tf.data.Dataset, parallel_calls: int):
        dataset = dataset.map(
            self._extract_specified_modalities_and_ensure_shape,
            num_parallel_calls=parallel_calls)
        dataset = dataset.map(self.concat_with_labels,
                              num_parallel_calls=parallel_calls)
        return dataset
Example #4
def preprocess_mlp_text(dataset: tf.data.Dataset,
                        parameter: Dict) -> np.ndarray:
    """
    Perform tokenization for MLP model

    :param tf.data.Dataset dataset: dataset containing text and label data
    :param Dict parameter: parameter object containing the vocab_size and sequence_length parameters
    :return: tokenized and padded text
    :rtype: np.ndarray
    """
    vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        standardize=custom_standardization,
        max_tokens=parameter["vocab_size"],
        output_mode='int',
        output_sequence_length=parameter["sequence_length"])

    def vectorize_text_func(text, label):
        text = tf.expand_dims(text, -1)
        return vectorize_layer(text), label

    text_ds = dataset.map(lambda x, y: x)
    vectorize_layer.adapt(text_ds)
    text_ds = dataset.map(vectorize_text_func)

    tokenized_texts = []
    for text, _ in text_ds:
        tokenized_texts.append(text.numpy())

    return np.vstack(tokenized_texts)
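A small usage sketch; custom_standardization is assumed to be defined elsewhere in the project, so a trivial lower-casing stand-in is used here:

import numpy as np
import tensorflow as tf

def custom_standardization(text):
    # Stand-in for the project's standardization function.
    return tf.strings.lower(text)

texts = tf.constant(["Good movie", "Bad movie", "Great plot"])
labels = tf.constant([1, 0, 1])
dataset = tf.data.Dataset.from_tensor_slices((texts, labels)).batch(2)

parameter = {"vocab_size": 1000, "sequence_length": 8}
tokens = preprocess_mlp_text(dataset, parameter)
print(tokens.shape)  # (3, 8)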
Example #5
File: vis.py  Project: jackd/ecn
def _get_cached_adjacency_dataset(
    meta_model_func: Callable,
    dataset: tf.data.Dataset,
    augment_func: Optional[Callable] = None,
):
    if augment_func is not None:
        dataset = dataset.map(augment_func)
    batcher = tf.data.experimental.dense_to_ragged_batch(batch_size=1)
    builder = pl.PipelinedModelBuilder(dataset.element_spec, batcher=batcher)
    with builder:
        with comp.stream_accumulator() as streams:
            with comp.convolver_accumulator() as convolvers:
                inputs = builder.pre_cache_inputs
                logging.info("Building multi graph...")
                meta_model_func(*inputs)
                logging.info("Successfully built!")
        stream_indices = {s: i for i, s in enumerate(streams)}
        outputs = []
        decay_times = []
        stream_indices = tuple(
            (stream_indices[c.in_stream], stream_indices[c.out_stream])
            for c in convolvers
        )
        for c in convolvers:
            outputs.append((c.indices, c.splits, c.in_stream.times, c.out_stream.times))
            decay_times.append(c.decay_time)

    dataset = dataset.map(ModelMap(builder.build_pre_cache_model(outputs)))
    return dataset, decay_times, stream_indices, len(convolvers)
Example #6
def prepare_Dataset(dataset: tf.data.Dataset,
                    shuffle: bool = False,
                    augment: bool = False) -> tf.data.Dataset:
    """Prepare the dataset object with preprocessing and data augmentation.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The dataset object
    shuffle : bool, optional
        Whether to shuffle the dataset, by default False
    augment : bool, optional
        Whether to augment the train dataset, by default False

    Returns
    -------
    tf.data.Dataset
        The prepared dataset
    """
    preprocessing_model = build_preprocessing()
    dataset = dataset.map(map_func=lambda x, y: (preprocessing_model(x), y),
                          num_parallel_calls=AUTOTUNE)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=1_000)

    dataset = dataset.batch(batch_size=BATCH_SIZE)

    if augment:
        data_augmentation_model = build_data_augmentation()
        # Pass training=True so the random augmentation layers apply.
        dataset = dataset.map(
            map_func=lambda x, y: (data_augmentation_model(x, training=True), y),
            num_parallel_calls=AUTOTUNE)

    return dataset.prefetch(buffer_size=AUTOTUNE)
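This variant relies on module-level helpers and constants that are not shown. A minimal sketch of plausible definitions for a recent TF 2.x, where only the names come from the snippet and the bodies are assumptions:

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 32  # illustrative value

def build_preprocessing() -> tf.keras.Model:
    # Deterministic preprocessing, e.g. rescaling pixel values to [0, 1].
    return tf.keras.Sequential([tf.keras.layers.Rescaling(1.0 / 255)])

def build_data_augmentation() -> tf.keras.Model:
    # Random augmentation layers; they only transform inputs when called
    # with training=True.
    return tf.keras.Sequential([
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.1),
    ])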
Example #7
def batch_dataset(dataset: tf.data.Dataset,
                  model: LineRecognizer,
                  batch_size=32,
                  bucket_boundaries=None,
                  padded=True):
    # add image widths and text length
    dataset = dataset.map(
        lambda i, t: (i, tf.shape(i)[1], t,
                      tf.strings.length(t, unit='UTF8_CHAR')))

    dataset = dataset.map(lambda image, width, text, length:
                          (image, width, model.encoder.encode(text), length))

    output_shapes = (model.image_shape, [], [None], [])

    if bucket_boundaries:
        if isinstance(batch_size, int):
            batch_size = [batch_size] * (len(bucket_boundaries) + 1)

        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                lambda i, w, label, length: w,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=batch_size,
                padded_shapes=output_shapes))

    elif padded:
        dataset = dataset.padded_batch(batch_size=batch_size,
                                       padded_shapes=output_shapes)
    else:
        dataset = dataset.batch(batch_size)

    return dataset
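The bucketing branch is the core of this example; a self-contained sketch of tf.data.experimental.bucket_by_sequence_length on toy variable-length data (independent of LineRecognizer, which is not shown here):

import tensorflow as tf

# Toy elements: (sequence, length) pairs of varying length.
def gen():
    for n in (3, 5, 9, 12, 4, 11):
        yield tf.ones([n], tf.float32), n

ds = tf.data.Dataset.from_generator(
    gen, output_signature=(tf.TensorSpec([None], tf.float32),
                           tf.TensorSpec([], tf.int32)))

ds = ds.apply(tf.data.experimental.bucket_by_sequence_length(
    lambda seq, length: length,
    bucket_boundaries=[6, 10],     # three buckets: <6, 6-9, >=10
    bucket_batch_sizes=[3, 2, 2],  # len(bucket_boundaries) + 1 sizes
    padded_shapes=([None], [])))

for seqs, lengths in ds:
    print(seqs.shape, lengths.numpy())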
Example #8
def _prepare(config: dict, dataset: tf.data.Dataset) -> tf.data.Dataset:
    dataset = dataset.map(lambda x, y: _do_resize(x, y, config['img_size']),
                          num_parallel_calls=config['n_threads'])
    if config['augment']:
        dataset = dataset.map(_do_augment,
                              num_parallel_calls=config['n_threads'])
    dataset = dataset.map(_normalize, num_parallel_calls=config['n_threads'])
    return dataset
Example #9
  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if self._num_gpus > 1:
      dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
      # Shuffle the input files.
      dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
      dataset = dataset.repeat()

    # Read the data from disk in parallel
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
      dataset = dataset.cache()

    if self.is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)
      dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
      # The batch size of the dataset will be multiplied by the number of
      # replicas automatically when strategy.distribute_datasets_from_function
      # is called, so we use local batch size here.
      dataset = dataset.batch(self.local_batch_size,
                              drop_remainder=self.is_training)
    else:
      dataset = dataset.batch(self.global_batch_size,
                              drop_remainder=self.is_training)

    # Apply Mixup
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(self.mixup, self.local_batch_size, mixup_alpha),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
Example #10
def processing(dataset: tf.data.Dataset, window_size, batch_size):
    dataset = dataset.map(lambda x: table.lookup(x))
    dataset = dataset.unbatch()
    dataset = dataset.window(window_size+1, shift = 1, drop_remainder=True)
    dataset = dataset.flat_map(lambda ds: ds.batch(window_size+1))
    dataset = dataset.map(lambda x: (x[:-1], x[-1]-1))
    dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset
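processing() depends on a table defined outside the function; a minimal sketch, assuming a plain word-to-id vocabulary built with tf.lookup (ids start at 1, which matches the x[-1] - 1 label shift above):

import tensorflow as tf

vocab = ["the", "cat", "sat", "on", "mat"]  # illustrative vocabulary
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(vocab),
        values=tf.range(1, len(vocab) + 1, dtype=tf.int64)),
    default_value=0)

# Illustrative input: a dataset whose elements are tokenized sentences.
sentences = tf.data.Dataset.from_tensor_slices(
    [["the", "cat", "sat", "on", "the", "mat"]])
ds = processing(sentences, window_size=3, batch_size=2)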
Example #11
def accumulated_batch(
    dataset: tf.data.Dataset,
    accumulator: Union[Accumulator, Mapping, Iterable],
    **map_kwargs,
):
    accumulator = accumulator_structure(accumulator)

    def initial_map_fn(*args):
        if len(args) == 1:
            (args,) = args
        return args, False

    @tf.function
    def scan_fn(state, el_and_final):
        el, final = el_and_final
        new_state = accumulator.append(state, el)
        valid = tf.reduce_all(accumulator.valid_conditions(new_state))
        invalid = tf.logical_not(valid)
        if invalid:
            new_state = accumulator.append(accumulator.initial_state(), el)
        return new_state, (state, tf.logical_or(invalid, final))

    def filter_fn(state, invalid):
        del state
        return invalid

    def map_fn(state, invalid):
        del invalid
        return accumulator.finalize(state)

    cardinality = tf.data.experimental.cardinality(dataset)

    dataset = dataset.map(initial_map_fn)
    if cardinality != tf.data.experimental.INFINITE_CARDINALITY:
        # append (empty, True) element to ensure final elements are generated
        state_spec = dataset.element_spec[0]
        empty_el = tf.nest.map_structure(
            lambda spec: tf.zeros(
                [1, *(0 if s is None else s for s in spec.shape)], dtype=spec.dtype
            ),
            state_spec,
        )
        true_el = tf.ones((1,), dtype=tf.bool)
        dataset = dataset.concatenate(
            tf.data.Dataset.from_tensor_slices((empty_el, true_el))
        )

    dataset = dataset.apply(
        tf.data.experimental.scan(accumulator.initial_state(), scan_fn)
    )

    dataset = dataset.filter(filter_fn)
    dataset = dataset.map(map_fn, **map_kwargs)
    return dataset
Example #12
 def transform_and_filter(self, ds: tf.data.Dataset):
     """
     Map from example into viable fit input
     """
     ds = ds.map(
         self.decode_record, num_parallel_calls=tf.data.experimental.AUTOTUNE
     )
     ds = ds.map(
         self.select_data_from_record,
         num_parallel_calls=tf.data.experimental.AUTOTUNE,
     )
     return ds
Example #13
    def change_image_size(
        self, train: tf.data.Dataset, validation: tf.data.Dataset,
        test: tf.data.Dataset
    ) -> (tf.data.Dataset, tf.data.Dataset, tf.data.Dataset):
        size = int(self.options['Size'])

        def resize(x, y):
            # tf.image.resize expects the target size as a (height, width) pair.
            x = tf.image.resize(x, (size, size))
            return x, y

        train = train.map(map_func=resize)
        validation = validation.map(map_func=resize)
        test = test.map(map_func=resize)
        return (train, validation, test)
Example #14
def _concatenate(in_ds: tf.data.Dataset,
                 out_ds: tf.data.Dataset) -> tf.data.Dataset:
    """Concatenate in_ds and out_ds, making sure they have compatible specs."""
    in_spec = in_ds.element_spec
    out_spec = out_ds.element_spec

    def format_in_ds(feature):
        feature = _set_label_to_one(feature)
        return _keep_common_fields(feature, out_spec)

    def format_out_ds(feature):
        feature = _set_label_to_zero(feature)
        return _keep_common_fields(feature, in_spec)

    return in_ds.map(format_in_ds).concatenate(out_ds.map(format_out_ds))
Example #15
    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset that contains instance cropped data."""

        def norm_instance(example):
            """Local processing function for dataset mapping."""
            centroids = example[self.centroid_key] / example["scale"]

            bboxes = example["bbox"]
            bboxes = expand_to_rank(bboxes, 2)
            bboxes_x1y1 = tf.gather(bboxes, [1, 0], axis=1)

            pts = example[self.peaks_key]
            pts += bboxes_x1y1
            pts /= example["scale"]

            example[self.new_centroid_key] = centroids
            example[self.new_centroid_confidence_key] = example[
                self.centroid_confidence_key
            ]
            example[self.new_peaks_key] = pts
            example[self.new_peak_confidences_key] = example[self.peak_confidences_key]
            return example

        # Map the main processing function to each example.
        output_ds = input_ds.map(
            norm_instance, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )

        return output_ds
Example #16
    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        device_name = self.device_name
        if device_name is None:
            device_name = best_logical_device_name()

        def predict(example):
            with tf.device(device_name):
                X = []
                for input_key in self.model_input_keys:
                    input_rank = tf.rank(example[input_key])
                    X.append(
                        expand_to_rank(example[input_key], target_rank=4, prepend=True)
                    )

                Y = self.keras_model(X)
                if not isinstance(Y, list):
                    Y = [Y]

                for output_key, y in zip(self.model_output_keys, Y):
                    if isinstance(y, list):
                        y = y[0]
                    if input_rank < tf.rank(y):
                        y = tf.squeeze(y, axis=0)
                    example[output_key] = y

                return example

        output_ds = input_ds.map(
            predict, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        return output_ds
Example #17
    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        def find_peaks(example):
            # Match example centroid to the instance with the closest node.
            centroid = example["centroid"] / example["scale"]
            all_peaks = example[self.all_peaks_in_key]  # (n_instances, n_nodes, 2)
            dists = tf.reduce_min(
                tf.norm(
                    all_peaks - tf.reshape(centroid, [1, 1, 2]), axis=-1
                ),
                axis=1,
            )  # (n_instances,)
            instance_ind = tf.argmin(dists)
            center_instance = tf.gather(all_peaks, instance_ind)

            # Adjust to coordinates relative to bounding box.
            center_instance -= tf.reshape(tf.gather(example["bbox"], [1, 0]), [1, 2])

            # Fill in mock data.
            example[self.peaks_out_key] = center_instance
            example[self.peak_vals_key] = tf.ones(
                [tf.shape(center_instance)[0]], dtype=tf.float32
            )
            example.pop(self.all_peaks_in_key)

            if self.keep_confmaps:
                example[self.confmaps_out_key] = example[self.confmaps_in_key]
                example.pop(self.confmaps_in_key)

            return example

        output_ds = input_ds.map(
            find_peaks, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        return output_ds
Example #18
    def add_field_names_preprocessor(
            dataset: tf.data.Dataset) -> tf.data.Dataset:
        return dataset.map(
            lambda *row: {
                field_name: row[field_index]
                for field_name, field_index in zip(field_names, field_indices)
            })
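field_names and field_indices come from an enclosing scope that is not shown. A sketch of how such a closure might be set up; the factory name and the usage values are assumptions:

from typing import Sequence

import tensorflow as tf

def make_field_names_preprocessor(field_names: Sequence[str],
                                  field_indices: Sequence[int]):
    """Returns a preprocessor that turns positional rows into named features."""

    def add_field_names_preprocessor(
            dataset: tf.data.Dataset) -> tf.data.Dataset:
        return dataset.map(
            lambda *row: {
                field_name: row[field_index]
                for field_name, field_index in zip(field_names, field_indices)
            })

    return add_field_names_preprocessor

# Keep columns 0 and 2 of a CSV-like row dataset as "id" and "label".
preprocessor = make_field_names_preprocessor(["id", "label"], [0, 2])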
Example #19
    def preprocess(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
        """Apply preprocessing specifitly for this model.

        Extract the features from the image with the encoder.
        Flatten and concatenate them with the clearsky.
        Change target to only consider present time.
        Data is now (features, target).
        """
        def encoder(image):
            # Create Fake Batch Size
            image = tf.expand_dims(image, 0)
            image_encoded = self.encoder((image), training=False)
            # Remove Fake Batch Size
            return self.flatten(image_encoded)[0, :]

        def preprocess(image, clearsky, target_ghi):
            # Normalize inputs
            image = self.scaling_image.normalize(image)
            clearsky = self.scaling_ghi.normalize(clearsky)
            target_ghi = self.scaling_ghi.normalize(target_ghi)

            image_features = tf.py_function(func=encoder,
                                            inp=[image],
                                            Tout=tf.float32)
            clearsky = self._preprocess_target(clearsky)
            target_ghi = self._preprocess_target(target_ghi)

            features = tf.concat([image_features, clearsky], 0)
            return (features, target_ghi)

        return dataset.map(preprocess).cache()
Example #20
def preprocessing(dsData: tf.data.Dataset, window_size, batch_size):
    dsData = dsData.window(window_size + 1, shift=1, drop_remainder=True)
    dsData = dsData.flat_map(lambda w: w.batch(window_size + 1))
    dsData = dsData.map(lambda x: (x[:-1], x[-1]))
    dsData = dsData.shuffle(1000)
    dsData = dsData.batch(batch_size).prefetch(1)
    return dsData
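A quick check on a toy series; each element becomes a (window of window_size values, next value) pair, and the printed order varies because of the shuffle:

import tensorflow as tf

series = tf.data.Dataset.range(10)
ds = preprocessing(series, window_size=3, batch_size=2)
for x, y in ds.take(1):
    print(x.numpy(), y.numpy())
# e.g. [[0 1 2] [4 5 6]] and [3 7]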
Example #21
def mixup(
    ds: tf.data.Dataset,
    postmix_fn: typing.Callable[..., typing.Any] = None,
    num_parallel_calls: int = None,
):
    """tf.dataでのmixup: <https://arxiv.org/abs/1710.09412>

    Args:
        ds: 元のデータセット
        postmix_fn: mixup後の処理
        num_parallel_calls: premix_fnの並列数

    """
    @tf.function
    def mixup_fn(*data):
        r = _tf_random_beta(alpha=0.2, beta=0.2)
        data = [
            tf.cast(d[0], tf.float32) * r + tf.cast(d[1], tf.float32) * (1 - r)
            for d in data
        ]
        return data if postmix_fn is None else postmix_fn(*data)

    ds = ds.repeat()
    ds = ds.batch(2)
    ds = ds.map(
        mixup_fn,
        num_parallel_calls=num_parallel_calls,
        deterministic=None if num_parallel_calls is None else False,
    )
    return ds
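_tf_random_beta is not shown. A Beta(alpha, beta) sample can be built from two Gamma samples, so a minimal stand-in might look like this; it is an assumption, not the project's implementation:

import tensorflow as tf

def _tf_random_beta(alpha: float, beta: float) -> tf.Tensor:
    # If X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1), then X / (X + Y) ~ Beta(alpha, beta).
    x = tf.random.gamma([], alpha)
    y = tf.random.gamma([], beta)
    return x / (x + y)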
Example #22
    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset that contains instance cropped data."""
        def rescale_points(example):
            """Local processing function for dataset mapping."""
            # Pull out data.
            points = example[self.points_key]
            scale = example[self.scale_key]

            # Make sure the scale lines up with the last dimension of the points.
            scale = expand_to_rank(scale, tf.rank(points))

            # Scale.
            if self.invert:
                points /= scale
            else:
                points *= scale

            # Update example.
            example[self.points_key] = points
            return example

        # Map the main processing function to each example.
        output_ds = input_ds.map(
            rescale_points, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        return output_ds
Example #23
    def preprocessing(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
        """Proprocess dataset to have ((encoder_input, decoder_input), target)."""
        # Teacher forcing.
        def preprocess(input_sentence, output_sentence):
            return ((input_sentence, output_sentence[:-1]), output_sentence[1:])

        return dataset.map(preprocess)
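A runnable check of the same shift, applied outside the class for illustration; the single tokenized pair below is made up (1 and 2 act as start/end tokens):

import tensorflow as tf

pairs = tf.data.Dataset.from_tensor_slices(
    ([[5, 6, 7]], [[1, 8, 9, 2]]))  # (input_sentence, output_sentence)
shifted = pairs.map(lambda inp, out: ((inp, out[:-1]), out[1:]))
for (encoder_input, decoder_input), target in shifted.take(1):
    print(decoder_input.numpy(), target.numpy())  # [1 8 9] [8 9 2]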
Example #24
    def prepare_dataset(self,
                        dataset: tf.data.Dataset,
                        buckets: List[int],
                        batch_sizes: List[int],
                        shuffle: bool = False) -> tf.data.Dataset:
        dataset = dataset.map(self._deserialization_func,
                              num_parallel_calls=128)

        buckets_array = np.array(buckets)
        batch_sizes_array = np.array(batch_sizes)

        if np.any(batch_sizes_array == 0) and shuffle:
            iszero = np.where(batch_sizes_array == 0)[0][0]
            filterlen = buckets_array[iszero - 1]
            print("Filtering sequences of length {}".format(filterlen))
            dataset = dataset.filter(
                lambda example: example['protein_length'] < filterlen)
        else:
            batch_sizes_array[batch_sizes_array <= 0] = 1

        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            operator.itemgetter('protein_length'), buckets_array,
            batch_sizes_array)
        dataset = dataset.apply(batch_fun)
        return dataset
Example #25
def append_conditional_augmentation(
        dataset: tf.data.Dataset,
        augmentations: list,
        accept_probability: float = 0.1) -> tf.data.Dataset:
    """Append augmentations to a TF dataset, each with a probability of
    acceptance.

    Parameters
    ----------
    dataset : tf.data.Dataset
        A tensorflow dataset to be augmented.
    augmentations : list of functions
        A list of functions which accept a tf.Tensor and augment it.
    accept_probability : float
        A float representing the probability of performing the augmentation.

    Returns
    -------
    dataset : tf.data.Dataset
        The augmented dataset.
    """

    for augmentation in augmentations:
        dataset = dataset.map(lambda *x: tf.cond(
            tf.random.uniform([], 0.0, 1.0) < accept_probability,
            lambda: augmentation(*x),
            lambda: x,
        ))
    return dataset
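A usage sketch with two illustrative augmentation functions; the dataset and the functions are assumptions, each returning the same (image, label) structure so both tf.cond branches match:

import tensorflow as tf

def flip_lr(image, label):
    return tf.image.flip_left_right(image), label

def add_noise(image, label):
    return image + tf.random.normal(tf.shape(image), stddev=0.05), label

images = tf.random.uniform((8, 32, 32, 3))
labels = tf.zeros((8,), dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Each augmentation is applied independently with 25% probability per element.
ds = append_conditional_augmentation(ds, [flip_lr, add_noise],
                                     accept_probability=0.25)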
Example #26
    def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
        specs = dataset.element_spec

        def pre_batch_map_func(*args):
            if len(args) == 1:
                (args, ) = args
            sizes = tf.nest.map_structure(
                lambda x: x.nrows()
                if isinstance(x, tf.RaggedTensor) else tf.shape(x)[0]
                if x.shape.ndims > 0 else tf.zeros((), dtype=tf.int32),
                args,
            )
            return args, sizes

        def post_batch_map_func(args, sizes):
            return tf.nest.map_structure(
                lambda arg, size, spec: tf.RaggedTensor.from_tensor(arg, size)
                if spec.shape[0] is None and isinstance(spec, tf.TensorSpec)
                else arg,
                args,
                sizes,
                specs,
            )

        return (dataset.map(pre_batch_map_func).padded_batch(
            self._batch_size,
            drop_remainder=self._drop_remainder).map(post_batch_map_func))
Example #27
 def preprocessing_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
   dataset = dataset.map(
       functools.partial(_map_fn, is_training=True, image_size=image_size),
       num_parallel_calls=tf.data.experimental.AUTOTUNE).shuffle(
           shuffle_buffer_size).take(max_elements).repeat(num_epochs).batch(
               batch_size)
   return dataset
Example #28
def to_rgb_32_x_32(D: tf.data.Dataset) -> tf.data.Dataset:
    def f_map(X, y):
        X = tf.image.resize(X, (32, 32))
        X = tf.image.grayscale_to_rgb(X)
        return X, y

    return D.map(f_map)
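For instance, adapting a grayscale dataset (28x28 single-channel images, made up here) to a model that expects 32x32 RGB input:

import tensorflow as tf

gray = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((4, 28, 28, 1)), tf.constant([0, 1, 2, 3])))
rgb = to_rgb_32_x_32(gray)
print(rgb.element_spec)  # image spec is now (32, 32, 3)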
Example #29
 def deserialize_images(dataset: tf.data.Dataset):
     """Convert dataset['image'] back into numpy array
     
     Arguments:
         dataset {tf.data.Dataset} -- dataset to process
     """
     return dataset.map(raw_to_numpy)
Example #30
    def get_bucket_iter(self,
                        dataset: tf.data.Dataset,
                        batch_size=32,
                        train=True) -> tf.data.Dataset:
        padded_shapes = self._padded_shapes()
        padding_values = self._padding_values()
        if train:
            bucket_boundaries = self._bucket_boundaries(batch_size)
            bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
            dataset = dataset.apply(
                tf.data.experimental.bucket_by_sequence_length(
                    self.element_length_func,
                    bucket_boundaries,
                    bucket_batch_sizes,
                    padded_shapes=padded_shapes,
                    padding_values=padding_values))
            dataset = dataset.shuffle(100)
        else:
            dataset = dataset.padded_batch(batch_size,
                                           padded_shapes=padded_shapes)

        dataset = dataset.map(self._collate_fn,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset