def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None, **kwargs):
  """Distribute the dataset on all workers.

  When `split_batch_by` is not None, every batch of `dataset` is first
  "split" by that value. The resulting dataset is then either sharded once
  according to `input_context` (between-graph case), or cloned and sharded
  once per worker device (in-graph case).

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, will attempt to shard
  # the final input such that each worker will run the entire preprocessing
  # pipeline and only receive its own shard of the dataset.
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  self._cloned_datasets = []
  if not input_context:
    # In-graph: produce one clone/shard of the dataset per worker device.
    num_shards = len(input_workers.worker_devices)
    for shard_index, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        per_worker = dataset
        if not context.executing_eagerly():
          # Graph mode: clone the dataset inside this worker's device scope
          # and carry over the original dataset's options.
          per_worker = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          per_worker = per_worker.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        per_worker = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            per_worker, num_shards, shard_index)
        self._cloned_datasets.append(per_worker)
  else:
    # Between-graph where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._strategy = strategy
  self._kwargs = kwargs
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, each batch of the dataset is "split" by
  the `split_batch_by` value. To achieve this, we first unbatch the input
  dataset and then rebatch it with the per replica batch size that is
  calculated using `global_batch_size // split_batch_by`. The currently
  supported datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  We clone and shard the dataset on each worker. The current setup tries to
  shard the dataset by files if possible so that each worker sees a different
  subset of files. If that is not possible, will attempt to shard the final
  input such that each worker will run the entire preprocessing pipeline and
  only receive its own shard of the dataset.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  num_shards = len(input_workers.worker_devices)
  worker_iterators = []
  for worker_index, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(worker_index)
      worker_dataset = dataset
      if not context.executing_eagerly():
        # Graph mode: clone under this worker's device scope, keeping the
        # original dataset's options.
        worker_dataset = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
            dataset.options())
      # TODO(b/129506833): Figure out between graph cases
      worker_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          worker_dataset, num_shards, worker_index)
      worker_iterators.append(
          _SingleWorkerDatasetIterator(worker_dataset, worker, compute_devices))

  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  super(DatasetIterator, self).__init__(input_workers, worker_iterators,
                                        **kwargs)
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, each batch of the dataset is "split" by
  the `split_batch_by` value. To achieve this, we first unbatch the input
  dataset and then rebatch it with the per replica batch size that is
  calculated using `global_batch_size // split_batch_by`. The currently
  supported datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  TODO(priyag): Support multi worker / host cases properly by cloning
  and sharding the dataset on each worker. Current setup will only work in
  some cases, such as in-graph multi worker GPU case. If the input pipeline
  has random shuffling (with a different seed on each worker), each worker
  will see random input from the same overall dataset in each step.
  Otherwise, each worker will see the same input in each step.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  per_worker_iterators = []
  for worker_index, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(worker_index)
      # Eager mode reuses `dataset` as-is; graph mode clones it under this
      # worker's device scope. No sharding happens here (see TODO above).
      if context.executing_eagerly():
        worker_dataset = dataset
      else:
        worker_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
        worker_dataset = worker_dataset.with_options(dataset.options())
      per_worker_iterators.append(
          _SingleWorkerDatasetIterator(worker_dataset, worker, compute_devices))

  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  super(DatasetIterator, self).__init__(input_workers, per_worker_iterators,
                                        **kwargs)
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, each batch of the dataset is "split" by
  the `split_batch_by` value. To achieve this, we first unbatch the input
  dataset and then rebatch it with the per replica batch size that is
  calculated using `global_batch_size // split_batch_by`. The currently
  supported datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  TODO(priyag): Support multi worker / host cases properly by cloning
  and sharding the dataset on each worker. Current setup will only work in
  some cases, such as in-graph multi worker GPU case. If the input pipeline
  has random shuffling (with a different seed on each worker), each worker
  will see random input from the same overall dataset in each step.
  Otherwise, each worker will see the same input in each step.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  def _iterator_for_worker(worker_index, worker):
    # Builds the single-worker iterator for `worker` inside its device scope.
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(worker_index)
      source = dataset
      if not context.executing_eagerly():
        source = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
            dataset.options())
      return _SingleWorkerDatasetIterator(source, worker, compute_devices)

  iterators = [
      _iterator_for_worker(worker_index, worker)
      for worker_index, worker in enumerate(input_workers.worker_devices)
  ]
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
def testMultipleVariantTensors(self):
  # Cloning a `_TestDataset` wrapper should yield an equivalent dataset.
  original = _TestDataset(dataset_ops.Dataset.range(10))
  clone = input_ops._clone_dataset(original)
  self._assert_datasets_equal(original, clone)
def testZip(self):
  # Cloning a dataset with two zipped inputs preserves its elements.
  left = dataset_ops.Dataset.range(10)
  right = dataset_ops.Dataset.range(10)
  zipped = dataset_ops.Dataset.zip((left, right))
  clone = input_ops._clone_dataset(zipped)
  self._assert_datasets_equal(zipped, clone)
def testConcat(self):
  # Cloning a concatenation of two datasets preserves its elements.
  head = dataset_ops.Dataset.range(10)
  tail = dataset_ops.Dataset.range(10)
  joined = head.concatenate(tail)
  clone = input_ops._clone_dataset(joined)
  self._assert_datasets_equal(joined, clone)
def testSimplePipeline(self):
  # Cloning a source followed by a single `map` preserves its elements.
  squared = dataset_ops.Dataset.range(10).map(math_ops.square)
  clone = input_ops._clone_dataset(squared)
  self._assert_datasets_equal(squared, clone)
def testOnlySource(self):
  # Cloning a bare source dataset preserves its elements.
  source = dataset_ops.Dataset.range(10)
  clone = input_ops._clone_dataset(source)
  self._assert_datasets_equal(source, clone)
def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Raises:
    ValueError: if `split_batch_by` is set and rebatching fails with an
      `InvalidArgumentError` whose message contains "without encountering a
      batch" (i.e. the input pipeline appears to have no `batch` step). Any
      other `InvalidArgumentError` from rebatching is re-raised unchanged.
  """
  super(DistributedDataset, self).__init__(input_workers=input_workers)

  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, will attempt to shard
  # the final input such that each worker will run the entire preprocessing
  # pipeline and only receive its own shard of the dataset.
  if split_batch_by:
    try:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        # Translate the low-level rebatch failure into an actionable error,
        # keeping the original traceback.
        six.reraise(
            ValueError,
            ValueError(
                "Call the `batch` method on the input Dataset in order to be "
                "able to split your input across {} replicas.\nPlease see "
                "the tf.distribute.Strategy guide. {}".format(
                    split_batch_by, e)),
            sys.exc_info()[2])
      else:
        raise

  self._cloned_datasets = []
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          # Graph mode: clone under this worker's device scope, keeping the
          # original dataset's options.
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._strategy = strategy