Code example #1
0
  def _get_input_iterator(
      self, input_fn: Callable[..., tf.data.Dataset],
      strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
    """Builds a distributed dataset iterator from `input_fn`.

    Args:
      input_fn: (params: dict) -> tf.data.Dataset.
      strategy: an instance of tf.distribute.Strategy.

    Returns:
      An iterator that yields input tensors, or None when `input_fn` is None.
    """

    if input_fn is None:
      return None
    # With multiple TPU workers the dataset must be cloned per worker, but a
    # Dataset instance cannot be cloned in eager mode — so the dataset-building
    # callable itself is handed to the strategy and each worker builds its own.
    if self._is_multi_host:
      distributed = strategy.distribute_datasets_from_function(input_fn)
    else:
      distributed = strategy.experimental_distribute_dataset(input_fn())
    return iter(distributed)
Code example #2
0
    def build(self,
              strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
        """Construct a dataset end-to-end and return it using an optional strategy.

    Args:
      strategy: a strategy that, if passed, will distribute the dataset
        according to that strategy. If passed and `num_devices > 1`,
        `use_per_replica_batch_size` must be set to `True`.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
        if strategy:
            if strategy.num_replicas_in_sync != self.config.num_devices:
                # A mismatch is tolerated but surfaced: batching math assumes
                # config.num_devices replicas.
                # logging.warn is a deprecated alias — use logging.warning.
                # Added the missing space so the message doesn't render as
                # "expected%d devices".
                logging.warning(
                    'Passed a strategy with %d devices, but expected '
                    '%d devices.', strategy.num_replicas_in_sync,
                    self.config.num_devices)
            # Let the strategy invoke the builder per replica/worker.
            dataset = strategy.distribute_datasets_from_function(self._build)
        else:
            dataset = self._build()

        return dataset