Example #1
 def _maybe_apply_data_service(
     self,
     dataset: tf.data.Dataset,
     input_context: Optional[tf.distribute.InputContext] = None
 ) -> tf.data.Dataset:
     """Potentially distributes a dataset."""
     if self._enable_tf_data_service and input_context:
         if self._enable_round_robin_tf_data_service:
             replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
                 input_context.num_input_pipelines)
             base_consumer_index = input_context.input_pipeline_id * (
                 replicas_per_input_pipeline)
             num_consumers = input_context.num_input_pipelines * (
                 replicas_per_input_pipeline)
             range_dataset = tf.data.Dataset.range(
                 replicas_per_input_pipeline)
             tfds_kwargs = {
                 'processing_mode': 'parallel_epochs',
                 'service': self._tf_data_service_address,
                 'job_name': self._tf_data_service_job_name,
                 'num_consumers': num_consumers
             }
             if self._enable_shared_tf_data_service_between_parallel_trainers:
                 raise ValueError(
                     'Shared tf.data service does not support round-robin'
                     ' tf.data service.')
             dataset = range_dataset.map(
                 lambda i: dataset.apply(  # pylint: disable=g-long-lambda
                     tf.data.experimental.service.distribute(
                         consumer_index=base_consumer_index + i,
                         **tfds_kwargs)))
             # Use parallel interleave to read multiple batches from a tf.data
             # service worker in parallel.
             dataset = dataset.interleave(
                 lambda x: x,
                 cycle_length=replicas_per_input_pipeline,
                 num_parallel_calls=replicas_per_input_pipeline,
                 deterministic=True)
         else:
             tfds_kwargs = {
                 'processing_mode': 'parallel_epochs',
                 'service': self._tf_data_service_address,
                 'job_name': self._tf_data_service_job_name,
             }
             if self._enable_shared_tf_data_service_between_parallel_trainers:
                 tfds_kwargs.update({
                     'processing_mode':
                     tf.data.experimental.service.ShardingPolicy.OFF,
                     'cross_trainer_cache':
                     tf.data.experimental.service.CrossTrainerCache(
                         trainer_id=self._trainer_id)
                 })
             dataset = dataset.apply(
                 tf.data.experimental.service.distribute(**tfds_kwargs))
     return dataset
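
A quick worked example of the round-robin index arithmetic above (the concrete values are assumptions for illustration):

num_replicas_in_sync = 8     # assumed: 8 replicas in total
num_input_pipelines = 2      # assumed: fed by 2 input pipelines
input_pipeline_id = 1        # this host runs the second pipeline

replicas_per_input_pipeline = num_replicas_in_sync // num_input_pipelines  # 4
base_consumer_index = input_pipeline_id * replicas_per_input_pipeline      # 4
num_consumers = num_input_pipelines * replicas_per_input_pipeline          # 8

# This pipeline drives consumer indices 4, 5, 6, 7 -- one per local replica.
print([base_consumer_index + i for i in range(replicas_per_input_pipeline)])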
Example #2
def apply(
    dataset: tf.data.Dataset,
    transform: Union[Transform, Iterable[Optional[Transform]]],
) -> tf.data.Dataset:
    if callable(transform):
        return dataset.apply(transform)

    for tr in transform:
        if tr is not None:
            dataset = dataset.apply(tr)
    return dataset
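
A minimal usage sketch for this helper; the concrete transforms are assumptions for illustration. Because None entries are skipped, optional stages can be toggled inline:

import tensorflow as tf

ds = tf.data.Dataset.range(100)
enable_cache = False
ds = apply(ds, [
    lambda d: d.batch(4),
    (lambda d: d.cache()) if enable_cache else None,  # skipped when None
    lambda d: d.prefetch(tf.data.AUTOTUNE),
])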
Example #3
    def get_bucket_iter(self,
                        dataset: tf.data.Dataset,
                        batch_size=32,
                        train=True) -> tf.data.Dataset:
        padded_shapes = self._padded_shapes()
        padding_values = self._padding_values()
        if train:
            bucket_boundaries = self._bucket_boundaries(batch_size)
            bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
            dataset = dataset.apply(
                tf.data.experimental.bucket_by_sequence_length(
                    self.element_length_func,
                    bucket_boundaries,
                    bucket_batch_sizes,
                    padded_shapes=padded_shapes,
                    padding_values=padding_values))
            dataset = dataset.shuffle(100)
        else:
            dataset = dataset.padded_batch(batch_size,
                                           padded_shapes=padded_shapes)

        dataset = dataset.map(self._collate_fn,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset
Example #4
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # repeat before map so stateless map is different across epochs
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # if map_func has random elements this won't be deterministic
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
Example #5
    def prepare_dataset(self,
                        dataset: tf.data.Dataset,
                        buckets: List[int],
                        batch_sizes: List[int],
                        shuffle: bool = False) -> tf.data.Dataset:
        dataset = dataset.map(self._deserialization_func,
                              num_parallel_calls=128)

        buckets_array = np.array(buckets)
        batch_sizes_array = np.array(batch_sizes)

        if np.any(batch_sizes_array == 0) and shuffle:
            iszero = np.where(batch_sizes_array == 0)[0][0]
            filterlen = buckets_array[iszero - 1]
            print("Filtering sequences of length {}".format(filterlen))
            dataset = dataset.filter(
                lambda example: example['protein_length'] < filterlen)
        else:
            batch_sizes_array[batch_sizes_array <= 0] = 1

        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            operator.itemgetter('protein_length'), buckets_array,
            batch_sizes_array)
        dataset = dataset.apply(batch_fun)
        return dataset
Example #6
def batch_dataset(dataset: tf.data.Dataset,
                  model: LineRecognizer,
                  batch_size=32,
                  bucket_boundaries=None,
                  padded=True):
    # add image widths and text length
    dataset = dataset.map(
        lambda i, t: (i, tf.shape(i)[1], t,
                      tf.strings.length(t, unit='UTF8_CHAR')))

    dataset = dataset.map(lambda image, width, text, length:
                          (image, width, model.encoder.encode(text), length))

    output_shapes = (model.image_shape, [], [None], [])

    if bucket_boundaries:
        if isinstance(batch_size, int):
            batch_size = [batch_size] * (len(bucket_boundaries) + 1)

        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                lambda i, w, label, length: w,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=batch_size,
                padded_shapes=output_shapes))

    elif padded:
        dataset = dataset.padded_batch(batch_size=batch_size,
                                       padded_shapes=output_shapes)
    else:
        dataset = dataset.batch(batch_size)

    return dataset
Example #7
def bucket_batch_dataset(dataset: tf.data.Dataset, batches: List[int],
                         boundaries: List[int]) -> tf.data.Dataset:
    op = tf.contrib.data.bucket_by_sequence_length(
        element_length_func=lambda x, y: tf.shape(x)[0],
        bucket_batch_sizes=batches,
        bucket_boundaries=boundaries)
    return dataset.apply(op)
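
tf.contrib was removed in TensorFlow 2.x; a sketch of the equivalent under TF2, assuming the same argument semantics:

import tensorflow as tf
from typing import List

def bucket_batch_dataset_v2(dataset: tf.data.Dataset, batches: List[int],
                            boundaries: List[int]) -> tf.data.Dataset:
    # tf.data.experimental.bucket_by_sequence_length mirrors the contrib op.
    op = tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x, y: tf.shape(x)[0],
        bucket_batch_sizes=batches,
        bucket_boundaries=boundaries)
    return dataset.apply(op)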
Example #8
def block_diagonal_batch_with_batch_size(dataset: tf.data.Dataset,
                                         batch_size: int):
    """
    Batch the input dataset block diagonally up to the given batch size.

    Args:
        dataset: tf.data.Dataset with spec ((nodes, (link*)), labels).
            nodes: [V?, ...] node features.
            link: [E?, 2] int edge/link indices.
            labels: [V?, ...] or [...] label data.
        batch_size: number of examples in the resulting batch.

    Returns:
        dataset with spec:
            nodes: [B, V?, ...] ragged node features.
            links: [E, 2] indices into flattened nodes.
            labels: [BV, ...] or [B, ...]
        B = batch_size
        BV = sum_b V_b
    """
    dataset = dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size,
                                                   row_splits_dtype=tf.int32))
    return dataset.map(
        lambda *args: _block_diagonalize_batched(*_unpack_dataset(*args)))
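
For reference, a toy sketch of dense_to_ragged_batch in isolation (example data assumed), showing that variable-length rows come out as one tf.RaggedTensor per batch:

import tensorflow as tf

ds = tf.data.Dataset.from_generator(
    lambda: ([1, 2, 3], [4, 5], [6]),
    output_signature=tf.TensorSpec([None], tf.int32))
ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
for batch in ds:
    print(batch)  # e.g. <tf.RaggedTensor [[1, 2, 3], [4, 5]]>, then [[6]]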
Example #9
def prefetch_to_available_gpu_device(dataset: tf.data.Dataset,
                                     buffer_size: Optional[int] = None,
                                     use_workaround: bool = False):
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        assert len(gpus) == 1, (
            'Expected to find exactly 1 GPU, but found: {}'.format(gpus))
        if use_workaround:
            dataset = dataset.apply(
                tf.data.experimental.copy_to_device("/GPU:0"))
            with tf.device("/GPU:0"):
                return dataset.prefetch(buffer_size)
        else:
            return dataset.apply(
                tf.data.experimental.prefetch_to_device('/GPU:0', buffer_size))
    else:
        return dataset
Example #10
def accumulated_batch(
    dataset: tf.data.Dataset,
    accumulator: Union[Accumulator, Mapping, Iterable],
    **map_kwargs,
):
    accumulator = accumulator_structure(accumulator)

    def initial_map_fn(*args):
        if len(args) == 1:
            (args,) = args
        return args, False

    @tf.function
    def scan_fn(state, el_and_final):
        el, final = el_and_final
        new_state = accumulator.append(state, el)
        valid = tf.reduce_all(accumulator.valid_conditions(new_state))
        invalid = tf.logical_not(valid)
        if invalid:
            new_state = accumulator.append(accumulator.initial_state(), el)
        return new_state, (state, tf.logical_or(invalid, final))

    def filter_fn(state, invalid):
        del state
        return invalid

    def map_fn(state, invalid):
        del invalid
        return accumulator.finalize(state)

    cardinality = tf.data.experimental.cardinality(dataset)

    dataset = dataset.map(initial_map_fn)
    if cardinality != tf.data.experimental.INFINITE_CARDINALITY:
        # append (empty, True) element to ensure final elements are generated
        state_spec = dataset.element_spec[0]
        empty_el = tf.nest.map_structure(
            lambda spec: tf.zeros(
                [1, *(0 if s is None else s for s in spec.shape)], dtype=spec.dtype
            ),
            state_spec,
        )
        true_el = tf.ones((1,), dtype=tf.bool)
        dataset = dataset.concatenate(
            tf.data.Dataset.from_tensor_slices((empty_el, true_el))
        )

    dataset = dataset.apply(
        tf.data.experimental.scan(accumulator.initial_state(), scan_fn)
    )

    dataset = dataset.filter(filter_fn)
    dataset = dataset.map(map_fn, **map_kwargs)
    return dataset
Example #11
    def batch(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
        bounds = list(range(self.hist_min, self.hist_max, self.hist_step))

        logging.info("Quantile bucketing from %d-%d with %d buckets" %
                     (bounds[0], bounds[-1], len(bounds)))

        return dataset.apply(
            ops.bucket_by_quantiles(
                len_fn=lambda x: tf.shape(x[PREMISE_KEY])[0],
                batch_size=self.batch_size,
                n_buckets=self.n_buckets,
                hist_bounds=bounds))
Example #12
 def prepare_dataset(self,
                     dataset: tf.data.Dataset,
                     buckets: List[int],
                     batch_sizes: List[int],
                     shuffle: bool = False) -> tf.data.Dataset:
     dataset = dataset.map(self._deserialization_func, 128)
     dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
     batch_fun = tf.data.experimental.bucket_by_sequence_length(
         lambda example: tf.maximum(example['first']['protein_length'],
                                    example['second']['protein_length']),
         buckets, batch_sizes)
     dataset = dataset.apply(batch_fun)
     return dataset
Example #13
 def _maybe_apply_data_service(
     self,
     dataset: tf.data.Dataset,
     input_context: Optional[tf.distribute.InputContext] = None
 ) -> tf.data.Dataset:
     """Potentially distributes a dataset."""
     if self._enable_tf_data_service and input_context:
         if self._enable_round_robin_tf_data_service:
             replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
                 input_context.num_input_pipelines)
             base_consumer_index = input_context.input_pipeline_id * (
                 replicas_per_input_pipeline)
             num_consumers = input_context.num_input_pipelines * (
                 replicas_per_input_pipeline)
             range_dataset = tf.data.Dataset.range(
                 replicas_per_input_pipeline)
             dataset = range_dataset.map(lambda i: dataset.apply(  # pylint: disable=g-long-lambda
                 tf.data.experimental.service.distribute(
                     processing_mode='parallel_epochs',
                     service=self._tf_data_service_address,
                     job_name=self._tf_data_service_job_name,
                     consumer_index=base_consumer_index + i,
                     num_consumers=num_consumers)))
             # Use parallel interleave to read multiple batches from a tf.data
             # service worker in parallel.
             dataset = dataset.interleave(
                 lambda x: x,
                 cycle_length=replicas_per_input_pipeline,
                 num_parallel_calls=replicas_per_input_pipeline,
                 deterministic=True)
         else:
             dataset = dataset.apply(
                 tf.data.experimental.service.distribute(
                     processing_mode='parallel_epochs',
                     service=self._tf_data_service_address,
                     job_name=self._tf_data_service_job_name))
     return dataset
Example #14
 def reduce_func(unused_key, window: tf.data.Dataset):
     # ToDo: use padded_batch instead of padded_batch_and_drop_remainder
     # Currently this model only works with static batch size
     apply_fn = tf.contrib.data.padded_batch_and_drop_remainder(
         batch_size,
         padded_shapes=(
             PreparedSourceData(
                 id=tf.TensorShape([]),
                 text=tf.TensorShape([]),
                 source=tf.TensorShape([None]),
                 source_length=tf.TensorShape([]),
                 text_positions=tf.TensorShape([None]),
                 text2=tf.TensorShape([]),
                 source2=tf.TensorShape([None]),
                 source_length2=tf.TensorShape([]),
                 text_positions2=tf.TensorShape([None]),
             ),
             _PreparedTargetData(
                 id=tf.TensorShape([]),
                 spec=tf.TensorShape([None, self.hparams.fft_size // 2 + 1]),
                 spec_width=tf.TensorShape([]),
                 mel=tf.TensorShape([None, self.hparams.num_mels]),
                 mel_width=tf.TensorShape([]),
                 target_length=tf.TensorShape([]),
                 done=tf.TensorShape([None]),
             )),
         padding_values=(
             PreparedSourceData(
                 id=tf.to_int64(0),
                 text="",
                 source=tf.to_int64(0),
                 source_length=tf.to_int64(0),
                 text_positions=tf.to_int64(0),
                 text2="",
                 source2=tf.to_int64(0),
                 source_length2=tf.to_int64(0),
                 text_positions2=tf.to_int64(0),
             ),
             _PreparedTargetData(
                 id=tf.to_int64(0),
                 spec=tf.to_float(0),
                 spec_width=tf.to_int64(0),
                 mel=tf.to_float(0),
                 mel_width=tf.to_int64(0),
                 target_length=tf.to_int64(0),
                 done=tf.to_float(1),
             )))
     return window.apply(apply_fn)
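
This example targets TF1: tf.contrib.data and tf.to_int64 were removed in TF2. A minimal sketch of the same reduce under TF2, assuming padded_shapes and padding_values are built as above with tf.constant(0, tf.int64) replacing tf.to_int64(0) and tf.constant(0.0) replacing tf.to_float(0):

import tensorflow as tf

def reduce_func_v2(unused_key, window: tf.data.Dataset, batch_size,
                   padded_shapes, padding_values) -> tf.data.Dataset:
    # drop_remainder=True subsumes padded_batch_and_drop_remainder in TF2.
    return window.padded_batch(batch_size,
                               padded_shapes=padded_shapes,
                               padding_values=padding_values,
                               drop_remainder=True)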
Example #15
def prepare_dataset(
    dataset: tf.data.Dataset,
    model_image_size: Tuple[int, int],
    augmentation_fn: Optional[ImageDataMapFn] = None,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int] = None,
    num_parallel_calls: Optional[int] = None,
    prefetch_buffer_size: Optional[int] = None,
    prefetch_to_device: Optional[str] = None,
) -> tf.data.Dataset:

    # apply data augmentation:
    if augmentation_fn is not None:
        dataset = dataset.map(
            map_image_data(augmentation_fn),
            num_parallel_calls=num_parallel_calls,
        )

    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(
        map_image_data(prepare_for_batching(model_image_size)),
        num_parallel_calls=num_parallel_calls,
    )

    # batching and padding
    if batch_size is not None:
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=get_padding_shapes(
                dataset, spatial_image_shape=model_image_size),
            drop_remainder=True,
        )

    # try to prefetch dataset on certain device
    if prefetch_to_device is not None:
        prefetch_fn = tf.data.experimental.prefetch_to_device(
            device=prefetch_to_device, buffer_size=prefetch_buffer_size)
        dataset = dataset.apply(prefetch_fn)
    else:
        if prefetch_buffer_size is not None:
            dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

    return dataset
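
Usage note: TensorFlow documents prefetch_to_device as the final transformation in a pipeline, which the ordering above respects. A standalone sketch, assuming a GPU is present:

import tensorflow as tf

ds = tf.data.Dataset.range(8).batch(2)
# Must come last: no further transformations after prefetch_to_device.
ds = ds.apply(tf.data.experimental.prefetch_to_device('/GPU:0', buffer_size=2))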
Example #16
def send_to_data_service(dataset: tf.data.Dataset,
                         compute_config: TfDataServiceConfig,
                         rank: int,
                         size: Optional[int] = None,
                         processing_mode: str = 'distributed_epoch',
                         reuse_dataset: bool = False,
                         round_robin: bool = False) -> tf.data.Dataset:
    if compute_config.dispatcher_side == 'training':
        raise RuntimeError(
            'training side dispatcher not supported, use tf_data_service context manager instead'
        )

    with tf_data_service(compute_config, rank) as dispatcher_address:
        return dataset.apply(
            tf.data.experimental.service.distribute(
                processing_mode=processing_mode,
                service=dispatcher_address,
                job_name='job' if reuse_dataset else None,
                consumer_index=rank if reuse_dataset and round_robin else None,
                num_consumers=size if reuse_dataset and round_robin else None))
Example #17
def get_top_tokens(corpus: tf.data.Dataset,
                   n_top: int = 1000) -> Tuple[dict, int, int]:
    """
    Builds the token mapping which is used to initialize the word embeddings in the model.
    Get the most frequent terms which appear in the training corpus.

    Parameters
    ----------
    corpus : tf.data.Dataset
        Entire dataset object
    n_top : int, optional
        Number of most frequent vocab terms to keep for training, by default 1000

    Returns
    -------
    (dict, int, int)
        (token->integer lookup, maximum sequence length, size of data set)
    """

    lookup = Counter()
    max_sequence_length, data_set_size = 0, 0

    corpus = corpus.map(lambda x: tf.strings.split(x, sep=''),
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    for tokens_list in corpus.apply(
            tf.data.experimental.dense_to_ragged_batch(32)).prefetch(5):
        lookup.update(tokens_list.flat_values.numpy())

        max_batch_seq_len = int(tokens_list.row_lengths().numpy().max())
        if max_batch_seq_len > max_sequence_length:
            max_sequence_length = max_batch_seq_len
        data_set_size += int(tokens_list.nrows())

    # tensorflow converts strings to bytes, let's maintain that (no decoding)
    hash_map = {
        key: idx + 2
        for idx, (key, value) in enumerate(lookup.most_common(n_top))
    }
    hash_map["<s>".encode('utf8')] = 0
    hash_map["</s>".encode('utf8')] = 1
    return hash_map, max_sequence_length, data_set_size
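
A hypothetical call site, assuming a dataset of raw text lines:

corpus = tf.data.Dataset.from_tensor_slices(["hello world", "hi there"])
hash_map, max_len, n_examples = get_top_tokens(corpus, n_top=100)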
Example #18
    def prepare_dataset(
            self,  # type: ignore
            filenames: tf.data.Dataset,
            buckets: List[int],
            batch_sizes: List[int],
            shuffle: bool,
            is_holdout: bool,
            holdout_clans: set,
            holdout_families: set) -> tf.data.Dataset:
        def _check_membership(tensor, array):
            iscontained = tf.py_func(lambda t: t in array, [tensor], tf.bool)
            iscontained.set_shape(())
            return iscontained

        def _filter_fn(example):
            is_holdout_example = \
                _check_membership(example['clan'], holdout_clans) | \
                _check_membership(example['family'], holdout_families)
            return ~(is_holdout ^ is_holdout_example)

        def _load_records_and_preprocess(fname: tf.Tensor):
            dataset = tf.data.TFRecordDataset(fname)
            dataset = dataset.map(self._deserialization_func)
            # Hold out a prespecified set of families and clans
            dataset = dataset.filter(_filter_fn)
            return dataset

        dataset = filenames.apply(
            tf.data.experimental.parallel_interleave(
                _load_records_and_preprocess,
                sloppy=True,
                cycle_length=128,
                buffer_output_elements=32))

        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'], buckets, batch_sizes)
        dataset = dataset.apply(batch_fun)
        return dataset
Example #19
def read_dataset(dataset: tf.data.Dataset) -> Tuple[float, int]:
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(dataset_parser,
                                           batch_size=1,
                                           num_parallel_batches=2,
                                           drop_remainder=True))
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    next_element_from_dataset = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        data_samples = 0
        dataset_read_start_time = time.time()

        while True:
            try:
                sess.run(next_element_from_dataset)
                data_samples += 1
            except tf.errors.OutOfRangeError:
                break

        dataset_read_time = time.time() - dataset_read_start_time

    return dataset_read_time, data_samples
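
The loop above is TF1-style (Session plus a one-shot iterator), and map_and_batch is deprecated. An eager TF2 sketch with the same timing logic, assuming the same dataset_parser:

import time
from typing import Tuple
import tensorflow as tf

def read_dataset_v2(dataset: tf.data.Dataset) -> Tuple[float, int]:
    dataset = (dataset
               .map(dataset_parser, num_parallel_calls=2)
               .batch(1, drop_remainder=True)
               .prefetch(tf.data.AUTOTUNE))
    data_samples = 0
    start_time = time.time()
    for _ in dataset:  # eager iteration replaces Session.run
        data_samples += 1
    return time.time() - start_time, data_samples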
Example #20
    def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
        """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
        if (self.config.builder != 'tfds' and self.input_context
                and self.input_context.num_input_pipelines > 1):
            dataset = dataset.shard(self.input_context.num_input_pipelines,
                                    self.input_context.input_pipeline_id)
            logging.info(
                'Sharding the dataset: input_pipeline_id=%d '
                'num_input_pipelines=%d',
                self.input_context.num_input_pipelines,
                self.input_context.input_pipeline_id)

        if self.is_training and self.config.builder == 'records':
            # Shuffle the input files.
            dataset = dataset.shuffle(
                buffer_size=self.config.file_shuffle_buffer_size)

        if self.is_training and not self.config.cache:
            dataset = dataset.repeat()

        if self.config.builder == 'records':
            # Read the data from disk in parallel
            dataset = dataset.interleave(
                tf.data.TFRecordDataset,
                cycle_length=10,
                block_length=1,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if self.config.cache:
            dataset = dataset.cache()

        if self.is_training:
            dataset = dataset.shuffle(self.config.shuffle_buffer_size)
            dataset = dataset.repeat()

        # Parse, pre-process, and batch the data in parallel
        if self.config.builder == 'records':
            preprocess = self.parse_record
        else:
            preprocess = self.preprocess
        dataset = dataset.map(preprocess,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if self.input_context and self.config.num_devices > 1:
            if not self.config.use_per_replica_batch_size:
                raise ValueError(
                    'The builder does not support a global batch size with more than '
                    'one replica. Got {} replicas. Please set a '
                    '`per_replica_batch_size` and enable '
                    '`use_per_replica_batch_size=True`.'.format(
                        self.config.num_devices))

            # The batch size of the dataset will be multiplied by the number of
            # replicas automatically when strategy.distribute_datasets_from_function
            # is called, so we use local batch size here.
            dataset = dataset.batch(self.local_batch_size,
                                    drop_remainder=self.is_training)
        else:
            dataset = dataset.batch(self.global_batch_size,
                                    drop_remainder=self.is_training)

        # Prefetch overlaps in-feed with training
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        if self.config.tf_data_service:
            if not hasattr(tf.data.experimental, 'service'):
                raise ValueError(
                    'The tf_data_service flag requires Tensorflow version '
                    '>= 2.3.0, but the version is {}'.format(tf.__version__))
            dataset = dataset.apply(
                tf.data.experimental.service.distribute(
                    processing_mode='parallel_epochs',
                    service=self.config.tf_data_service,
                    job_name='resnet_train'))
            dataset = dataset.prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)

        return dataset
Example #21
 def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     return dataset.apply(
         tf.data.experimental.dense_to_ragged_batch(self._batch_size,
                                                    self._drop_remainder))
Example #22
def get_bucketed_batches(
    dataset: tf.data.Dataset,
    batch_size: int,
    bucket_size: int,
    max_length: int,
    padded_shapes: Any,
    example_size_fn: Any,
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
    drop_remainder: bool = False,
) -> tf.data.Dataset:
    """Returns padded batches of shuffled examples bucketed by length.

  This shuffles the examples randomly each epoch. The random order is
  deterministic and controlled by the seed.

  Batches are padded because sentences have different lengths.
  Sentences that are shorter in a batch will get 0s added at the end, until
  all sentences in the batch have the same length.

  For performance, examples of similar lengths are bucketed together. However,
  the contents of the buckets and their order is random each epoch, and
  controlled by the seed.

  Args:
    dataset: A TF Dataset with SST examples to be shuffled and batched.
    batch_size: The size of each batch. The remainder is dropped.
    bucket_size: How many different lengths go in each bucket.
    max_length: The maximum length to provide a bucket for.
    padded_shapes: A nested structure representing the shape to which the
      respective component of each input element should be padded prior to
      batching. See `tf.data.Dataset.padded_batch` for examples.
    example_size_fn: A TF function that returns the size of an example to
      determine in which bucket it goes. E.g., the sentence length.
    shuffle: Shuffle the dataset each epoch using seed.
    shuffle_seed: The seed that determines the shuffling order, with a
      different order each epoch.
    drop_remainder: Drop the last batch if it is not of size batch_size.

  Returns:
    A TF Dataset containing padded bucketed batches.
  """
    if shuffle:
        assert shuffle_seed is not None, 'When shuffling you must provide a seed.'

    # For bucket_size 8 and max length 24, we get bucket boundaries [9, 17, 25].
    bucket_boundaries = get_bucket_boundaries(bucket_size, max_length)
    logging.info('Batching bucket boundaries: %r', bucket_boundaries)

    # One batch size for each bucket plus one additional one (per requirement).
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    bucket_fn = tf.data.experimental.bucket_by_sequence_length(
        example_size_fn,
        bucket_boundaries,
        bucket_batch_sizes,
        padded_shapes=padded_shapes,
        pad_to_bucket_boundary=True,
        drop_remainder=drop_remainder)

    if shuffle:
        # For shuffling we need to know how many training examples we have.
        num_examples = get_num_examples(dataset)
        num_batches = num_examples // batch_size
        return dataset.shuffle(
            num_examples, seed=shuffle_seed,
            reshuffle_each_iteration=True).apply(bucket_fn).shuffle(
                num_batches, seed=shuffle_seed,
                reshuffle_each_iteration=True).prefetch(
                    tf.data.experimental.AUTOTUNE)
    return dataset.apply(bucket_fn).prefetch(tf.data.experimental.AUTOTUNE)
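
A hypothetical call site, assuming SST-style examples shaped like {'sentence': int64 [None], 'label': int64 []}:

batches = get_bucketed_batches(
    dataset,
    batch_size=64,
    bucket_size=8,
    max_length=60,
    padded_shapes={'sentence': [None], 'label': []},
    example_size_fn=lambda ex: tf.size(ex['sentence']),
    shuffle=True,
    shuffle_seed=42)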
Example #23
def batch(dataset: tf.data.Dataset, batch_size: int):
    return dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size))
Example #24
def build_meta_model_trainable(
    meta_model_func: Callable,
    train_dataset: tf.data.Dataset,
    validation_dataset: tf.data.Dataset,
    batcher: Transform,
    shuffle_buffer: int,
    compiler: Optional[Callable[[tf.keras.Model], Any]] = None,
    cache_factory: Optional[Callable[[str], Transform]] = None,
    cache_dir: Optional[str] = None,
    cache_repeats: Optional[int] = None,
    train_augment_func: Optional[Callable] = None,
    validation_augment_func: Optional[Callable] = None,
    callbacks: Iterable[tf.keras.callbacks.Callback] = (),
    seed: Optional[int] = None,
):
    # this version DOES bleed examples from 1 train epoch to the next
    # via a final full batch and shuffle buffer
    # Hopefully it gets rid of the memory leak we see?
    if cache_factory is not None:
        assert cache_dir is not None
        assert cache_repeats is not None
        cache_dir = expand(cache_dir)
    else:
        assert cache_dir is None
        assert cache_repeats is None

    if validation_augment_func:
        validation_dataset = validation_dataset.map(
            validation_augment_func, num_parallel_calls=AUTOTUNE)
    pipeline, model = pl.build_pipelined_model(meta_model_func,
                                               validation_dataset.element_spec,
                                               batcher)
    if compiler is not None:
        compiler(model)

    # finalize validation_dataset
    pre_cache, pre_batch, post_batch = _get_map_funcs(pipeline, training=False)
    validation_dataset = (validation_dataset.map(
        chain_map_funcs(pre_cache, pre_batch),
        AUTOTUNE).apply(batcher).map(post_batch, AUTOTUNE))
    if cache_factory is not None:
        validation_dataset = validation_dataset.apply(
            cache_factory(os.path.join(cache_dir, "validation")))

    # train_data
    steps_per_epoch = tf.keras.backend.get_value(
        train_dataset.apply(batcher).cardinality())
    pre_cache, pre_batch, post_batch = _get_map_funcs(pipeline, training=True)

    if cache_factory is None:
        train_dataset = train_dataset.repeat().shuffle(shuffle_buffer,
                                                       seed=seed)
        if train_augment_func is None:
            train_dataset = (train_dataset.map(
                chain_map_funcs(pre_cache, pre_batch),
                num_parallel_calls=AUTOTUNE,
            ).apply(batcher).map(post_batch, AUTOTUNE))
        else:
            # cache_factory is None, train_augment_func is not
            train_dataset = (train_dataset.repeat().apply(
                tfrng.data.stateless_map(
                    chain_map_funcs(train_augment_func, pre_cache, pre_batch),
                    seed=seed,
                    num_parallel_calls=AUTOTUNE,
                )).apply(batcher).map(post_batch, AUTOTUNE))
    else:
        # cache_factory is not None
        # We create separately cached datasets and flat_map over them.
        # This allows us to reuse the same caches if we want to change cache_repeats
        assert cache_repeats < 1e4  # for unique path naming
        paths = [
            os.path.join(cache_dir, "train", f"repeat-{i:04d}")
            for i in range(cache_repeats)
        ]
        if train_augment_func is None:
            # No augmentation
            assert cache_repeats == 1
            (path, ) = paths
            train_dataset = (train_dataset.map(
                pre_cache, num_parallel_calls=AUTOTUNE).apply(
                    cache_factory(path)).repeat())
        else:

            def get_cached(epoch_seed, path):
                return train_dataset.apply(
                    tfrng.data.stateless_map(
                        chain_map_funcs(train_augment_func, pre_cache),
                        seed=epoch_seed,
                        num_parallel_calls=AUTOTUNE,
                    )).apply(cache_factory(path))

            # train_dataset = (
            #     paths.apply(tfrng.data.with_seed(seed))
            #     .repeat()
            #     .flat_map(get_cached)
            # )
            # create cached datasets in eager mode.
            # For some cache_factory implementations this will mean all `cache_repeat`
            # cache files are run ahead of time.
            # This allows us to save iterators for training.
            datasets = [
                get_cached(s, p) for s, p in zip(
                    tf.data.experimental.RandomDataset(seed), paths)
            ]
            train_dataset = (tf.data.Dataset.from_tensor_slices(
                datasets).repeat().flat_map(lambda ds: ds))
        train_dataset = (
            train_dataset.shuffle(shuffle_buffer,
                                  seed=seed).map(pre_batch,
                                                 num_parallel_calls=AUTOTUNE).
            apply(batcher).map(post_batch, num_parallel_calls=AUTOTUNE).repeat(
            )  # HACK: because the assert_cardinality below raises on iteration
            # .apply(tf.data.experimental.assert_cardinality(
            #     tf.data.INFINITE_CARDINALITY))
        )
        # https://github.com/tensorflow/tensorflow/issues/45894

    # assert specs the same
    train_spec = train_dataset.element_spec
    val_spec = validation_dataset.element_spec
    tf.nest.assert_same_structure(train_spec, val_spec)
    flat_train = tf.nest.flatten(train_spec)
    flat_val = tf.nest.flatten(val_spec)
    for t, v in zip(flat_train, flat_val):
        assert t == v

    return Trainable(
        model=model,
        train_data=train_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_dataset,
        callbacks=tuple(callbacks),
    )
Example #25
 def transform(dataset: tf.data.Dataset) -> tf.data.Dataset:
     return dataset.apply(with_seed(seed, 2)).map(
         actual_map_func,
         num_parallel_calls=num_parallel_calls,
         deterministic=deterministic,
     )