Esempio n. 1
0
def get_datasets(train_batch_size: int,
                 val_batch_size: int,
                 strategy: tf.distribute.Strategy) -> Tuple[Any, Any]:
  """Build the distributed training and validation datasets.

  Each dataset function is obtained from `_make_get_dataset_fn` and then
  handed to the strategy, which distributes it.

  Args:
    train_batch_size: batch size forwarded to the 'train' dataset function.
    val_batch_size: batch size forwarded to the 'validation' dataset function.
    strategy: the tf.distribute.Strategy used to distribute both datasets.

  Returns:
    A `(train, validation)` pair of distributed datasets.
  """
  distribute = strategy.experimental_distribute_datasets_from_function

  # The trailing boolean is True for 'train' and False for 'validation';
  # presumably a training-mode flag -- confirm against _make_get_dataset_fn.
  train_fn = _make_get_dataset_fn('train', train_batch_size, True)
  distributed_train = distribute(train_fn)

  val_fn = _make_get_dataset_fn('validation', val_batch_size, False)
  distributed_val = distribute(val_fn)

  return distributed_train, distributed_val
Esempio n. 2
0
    def _get_input_iterator(
            self, input_fn: Callable[[Optional[params_dict.ParamsDict]],
                                     tf.data.Dataset],
            strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
        """Returns distributed dataset iterator.

    Args:
      input_fn: (params: dict) -> tf.data.Dataset.
      strategy: an instance of tf.distribute.Strategy.

    Returns:
      An iterator that yields input tensors.
    """

        if input_fn is None:
            return None
        # When training with multiple TPU workers, datasets needs to be cloned
        # across workers. Since Dataset instance cannot be cloned in eager mode,
        # we instead pass callable that returns a dataset.
        if self._is_multi_host:
            return iter(
                strategy.experimental_distribute_datasets_from_function(
                    input_fn))
        else:
            input_data = input_fn(self._params)
            return iter(strategy.experimental_distribute_dataset(input_data))
Esempio n. 3
0
    def build(self,
              strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
        """Construct a dataset end-to-end, optionally distributing it.

        Args:
          strategy: optional distribution strategy. When given, `self._build`
            is handed to the strategy, which distributes the resulting
            dataset. A warning is logged if the strategy's replica count
            differs from `self.config.num_devices`. Per the original
            documentation, if a strategy is passed and `num_devices > 1`,
            `use_per_replica_batch_size` must be set to `True`.

        Returns:
          The dataset returned by `self._build()` when no strategy is given,
          otherwise the distributed dataset produced by the strategy.
        """
        if not strategy:
            return self._build()

        replicas = strategy.num_replicas_in_sync
        expected = self.config.num_devices
        if replicas != expected:
            # `warning` rather than the deprecated `warn` alias; the trailing
            # space keeps the two adjacent literals from fusing into
            # "expected4 devices." when they are concatenated.
            logging.warning(
                'Passed a strategy with %d devices, but expected '
                '%d devices.', replicas, expected)

        return strategy.experimental_distribute_datasets_from_function(
            self._build)