Example #1
0
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None,
               **kwargs):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker runs the entire preprocessing
    # pipeline and receives only its own shard of the dataset.
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._strategy = strategy
    self._kwargs = kwargs
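
A minimal sketch of what the input_context branch above amounts to, assuming
only public tf.distribute and tf.data APIs (the real path goes through
input_ops.auto_shard_dataset, which can additionally shard by files):

import tensorflow as tf

def shard_by_input_context(dataset, input_context):
  # input_context is a tf.distribute.InputContext; each input pipeline keeps
  # every num_input_pipelines-th element, offset by its input_pipeline_id.
  return dataset.shard(input_context.num_input_pipelines,
                       input_context.input_pipeline_id)

# Pipeline 1 of 2 keeps elements 1, 3, 5, 7, 9.
ctx = tf.distribute.InputContext(num_input_pipelines=2, input_pipeline_id=1)
sharded = shard_by_input_context(tf.data.Dataset.range(10), ctx)
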
Example #2
0
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None,
               **kwargs):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker runs the entire preprocessing
    # pipeline and receives only its own shard of the dataset.
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._strategy = strategy
    self._kwargs = kwargs
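
The comment about preferring to shard by files, and falling back to sharding
the final input, can be illustrated with public tf.data ops. The file pattern,
shard count, and index below are hypothetical:

import tensorflow as tf

files = tf.data.Dataset.list_files("/tmp/data/part-*", shuffle=False)

# Preferred: shard the file list, so worker 1 of 4 only ever opens its own files.
per_worker_files = files.shard(num_shards=4, index=1)
by_files = per_worker_files.interleave(tf.data.TFRecordDataset, cycle_length=2)

# Fallback: every worker reads and preprocesses everything, then keeps only
# every fourth element of the final stream.
everything = files.interleave(tf.data.TFRecordDataset, cycle_length=2)
by_elements = everything.shard(num_shards=4, index=1)
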
Example #3
0
  def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    We clone and shard the dataset on each worker. The current setup tries to
    shard the dataset by files if possible so that each worker sees a
    different subset of files. If that is not possible, we will attempt to
    shard the final input such that each worker runs the entire preprocessing
    pipeline and receives only its own shard of the dataset.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
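
The unbatch-and-rebatch behaviour described in the docstring above can be
approximated with public tf.data ops; the real _RebatchDataset rewrites the
batch size inside the existing pipeline instead of materializing an unbatched
stream, and older releases spell unbatch as
dataset.apply(tf.data.experimental.unbatch()):

import tensorflow as tf

global_batch_size = 8
split_batch_by = 2  # e.g. the number of replicas
per_replica_batch_size = global_batch_size // split_batch_by

ds = tf.data.Dataset.range(32).batch(global_batch_size)
rebatched = ds.unbatch().batch(per_replica_batch_size)
# Each element of `rebatched` now holds 4 items instead of 8.
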
Example #4
0
    def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
        """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi-worker / host cases properly by cloning and
    sharding the dataset on each worker. The current setup will only work in
    some cases, such as the in-graph multi-worker GPU case. If the input
    pipeline has random shuffling (with a different seed on each worker), each
    worker will see random input from the same overall dataset in each step.
    Otherwise, each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
        assert isinstance(input_workers, InputWorkers)
        if split_batch_by:
            dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

        iterators = []
        for i, worker in enumerate(input_workers.worker_devices):
            with ops.device(worker):
                worker_devices = input_workers.compute_devices_for_worker(i)
                cloned_dataset = dataset
                if not context.executing_eagerly():
                    cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
                    cloned_dataset = cloned_dataset.with_options(
                        dataset.options())
                iterator = _SingleWorkerDatasetIterator(
                    cloned_dataset, worker, worker_devices)
                iterators.append(iterator)

        self._element_structure = dataset._element_structure  # pylint: disable=protected-access

        super(DatasetIterator, self).__init__(input_workers, iterators,
                                              **kwargs)
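
The four supported pipeline shapes listed in the docstring can be written out
with public tf.data ops; parse below is a hypothetical map function, and
map_and_batch lives in tf.data.experimental:

import tensorflow as tf

def parse(x):
  return x * 2

base = tf.data.Dataset.range(100)

ds1 = base.batch(8)
ds2 = base.apply(tf.data.experimental.map_and_batch(parse, batch_size=8))
ds3 = base.batch(8).prefetch(2)
ds4 = base.apply(
    tf.data.experimental.map_and_batch(parse, batch_size=8)).prefetch(2)
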
Example #5
0
  def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi-worker / host cases properly by cloning and
    sharding the dataset on each worker. The current setup will only work in
    some cases, such as the in-graph multi-worker GPU case. If the input
    pipeline has random shuffling (with a different seed on each worker), each
    worker will see random input from the same overall dataset in each step.
    Otherwise, each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
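
A small sketch of the shuffling behaviour the TODO(priyag) note describes:
without sharding, every worker draws from the same overall dataset, and only a
per-worker shuffle seed makes their streams differ. worker_id below is a
hypothetical per-worker value:

import tensorflow as tf

def make_worker_dataset(worker_id, batch_size=4):
  ds = tf.data.Dataset.range(100)
  # A distinct seed per worker yields a different ordering of the same 100
  # elements on each worker; a shared seed (or no shuffle at all) would make
  # every worker see identical input in each step.
  return ds.shuffle(buffer_size=100, seed=worker_id).batch(batch_size)
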
Example #6
0
 def testMultipleVariantTensors(self):
   ds = dataset_ops.Dataset.range(10)
   ds = _TestDataset(ds)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #7
0
 def testZip(self):
   ds1 = dataset_ops.Dataset.range(10)
   ds2 = dataset_ops.Dataset.range(10)
   ds = dataset_ops.Dataset.zip((ds1, ds2))
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #8
0
 def testConcat(self):
   ds1 = dataset_ops.Dataset.range(10)
   ds2 = dataset_ops.Dataset.range(10)
   ds = ds1.concatenate(ds2)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #9
0
 def testSimplePipeline(self):
   ds = dataset_ops.Dataset.range(10).map(math_ops.square)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #10
0
 def testOnlySource(self):
   ds = dataset_ops.Dataset.range(10)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
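
The tests above rely on _assert_datasets_equal from the surrounding test
class. A minimal eager-mode stand-in, assuming small datasets and a TF version
with Dataset.as_numpy_iterator(), might look like this:

import numpy as np
import tensorflow as tf

def assert_datasets_equal(ds_a, ds_b):
  a = list(ds_a.as_numpy_iterator())
  b = list(ds_b.as_numpy_iterator())
  assert len(a) == len(b)
  for x, y in zip(a, b):
    # Handles nested elements such as the tuples produced by Dataset.zip.
    tf.nest.map_structure(np.testing.assert_array_equal, x, y)
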
Example #11
0
    def __init__(self,
                 dataset,
                 input_workers,
                 strategy,
                 split_batch_by=None,
                 input_context=None):
        """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    """
        super(DistributedDataset, self).__init__(input_workers=input_workers)

        # We clone and shard the dataset on each worker. The current setup tries
        # to shard the dataset by files if possible so that each worker sees a
        # different subset of files. If that is not possible, we will attempt to
        # shard the final input such that each worker runs the entire
        # preprocessing pipeline and receives only its own shard of the dataset.
        if split_batch_by:
            try:
                dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access
            except errors.InvalidArgumentError as e:
                if "without encountering a batch" in str(e):
                    six.reraise(
                        ValueError,
                        ValueError(
                            "Call the `batch` method on the input Dataset in order to be "
                            "able to split your input across {} replicas.\n Please "
                            "the tf.distribute.Strategy guide. {}".format(
                                split_batch_by, e)),
                        sys.exc_info()[2])
                else:
                    raise

        self._cloned_datasets = []
        if input_context:
            # Between-graph where we rely on the input_context for sharding
            assert input_workers.num_workers == 1
            dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                dataset, input_context.num_input_pipelines,
                input_context.input_pipeline_id)
            self._cloned_datasets.append(dataset)
        else:
            for i, worker in enumerate(input_workers.worker_devices):
                with ops.device(worker):
                    cloned_dataset = dataset
                    if not context.executing_eagerly():
                        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
                        cloned_dataset = cloned_dataset.with_options(
                            dataset.options())
                    # TODO(b/129506833): Figure out between graph cases
                    cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                        cloned_dataset, len(input_workers.worker_devices), i)
                    self._cloned_datasets.append(cloned_dataset)

        self._input_workers = input_workers
        # TODO(anjalisridhar): Identify if we need to set this property on the
        # iterator.
        self._element_structure = dataset._element_structure  # pylint: disable=protected-access
        self._strategy = strategy
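
The with_options(dataset.options()) call in the graph-mode branch above copies
the original pipeline's options onto the clone. A small illustration with
public APIs, where the second range dataset stands in for the output of
_clone_dataset:

import tensorflow as tf

options = tf.data.Options()
options.experimental_deterministic = False

original = tf.data.Dataset.range(10).with_options(options)
clone = tf.data.Dataset.range(10)
clone = clone.with_options(original.options())  # the clone inherits the options
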
Example #12
0
 def testMultipleVariantTensors(self):
   ds = dataset_ops.Dataset.range(10)
   ds = _TestDataset(ds)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #13
0
 def testZip(self):
   ds1 = dataset_ops.Dataset.range(10)
   ds2 = dataset_ops.Dataset.range(10)
   ds = dataset_ops.Dataset.zip((ds1, ds2))
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #14
0
 def testConcat(self):
   ds1 = dataset_ops.Dataset.range(10)
   ds2 = dataset_ops.Dataset.range(10)
   ds = ds1.concatenate(ds2)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #15
0
 def testSimplePipeline(self):
   ds = dataset_ops.Dataset.range(10).map(math_ops.square)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)
Example #16
0
 def testOnlySource(self):
   ds = dataset_ops.Dataset.range(10)
   cloned_ds = input_ops._clone_dataset(ds)
   self._assert_datasets_equal(ds, cloned_ds)