def testParallelInterleaveBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(2).interleave(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder),
      cycle_length=2,
      num_parallel_calls=2)
  self.assertEqual(
      [[32 if drop_remainder else None]],
      [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
  # Two elements, each of which is range(32).
  expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]], [
      ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)
  ])
  # Eight batches of 8 covering 0..31, with each batch repeated twice in
  # collated fashion, i.e. [0...7], [0...7], [8...15], [8...15], etc.
  expected_output = [
      [k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
      for i in range(0, 32, 8)  # generates 4 distinct batches
      for _ in range(2)
  ]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testUnsupportedTransformError(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder).apply(
          scan_ops.scan([0], lambda _, a: ([0], a)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())
def testNotDivisibleError(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder)
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "not divisible by"):
    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=5)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())
def testTupleOutput(self, drop_remainder):
  dataset = (
      dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
          32, drop_remainder=drop_remainder))
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                      [k for k in range(i, i + 8)])
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testNestedDictionaryOutput(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).map(
      lambda x: {"a": x, "b": {"c": x}}).batch(
          32, drop_remainder=drop_remainder)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                      "b": {"c": [k for k in range(i, i + 8)]}}
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def __init__(self,
             dataset,
             input_workers,
             split_batch_by=None,
             input_context=None,
             **kwargs):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between-graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    **kwargs: Additional experimental flags. Will be removed in the future.
  """
  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we attempt to shard
  # the final input such that each worker runs the entire preprocessing
  # pipeline and only receives its own shard of the dataset.
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  self._cloned_datasets = []
  if input_context:
    # Between-graph case, where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between-graph cases.
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._kwargs = kwargs
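# Hedged sketch of the arithmetic behind `split_batch_by` (the names below are
# illustrative, not part of the API above): each global batch is split across
# replicas, assuming the divisibility that `_RebatchDataset` enforces.
global_batch_size = 64
split_batch_by = 8
assert global_batch_size % split_batch_by == 0, "batch size must be divisible"
per_replica_batch_size = global_batch_size // split_batch_by  # == 8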
def testBasic(self):
  dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual(
      [[32]], [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
  self.assertEqual(
      [[8]],
      [ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(rebatched_dataset, expected_output)
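# Conceptual model of what `_RebatchDataset` does to the data, written as
# plain Python for illustration only (`_rebatch_model` and its arguments are
# hypothetical; the real implementation is a graph rewrite). Treating the
# rewrite as batch(B) -> batch(B // num_workers) also explains the
# partial-batch behavior exercised in testFinalPartialBatchAfterRebatch below.
def _rebatch_model(elements, batch_size, num_workers, drop_remainder=False):
  per_worker = batch_size // num_workers  # the op requires divisibility
  out, buf = [], []
  for x in elements:
    buf.append(x)
    if len(buf) == per_worker:
      out.append(buf)
      buf = []
  if buf and not drop_remainder:
    out.append(buf)
  return out

# _rebatch_model(range(34), batch_size=32, num_workers=4)
#   -> [[0..7], [8..15], [16..23], [24..31], [32, 33]]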
def testFinalPartialBatchOriginal(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1032).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testZip(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      8, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8], [8]] if drop_remainder else [[None], [None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[2], [2]] if drop_remainder else [[None], [None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testMapAndBatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).apply(
      batching.map_and_batch(
          math_ops.square, 32, drop_remainder=drop_remainder))
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testFinalPartialBatchAfterRebatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(34).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
  if not drop_remainder:
    expected_output += [[32, 33]]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testConcatenate(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      8, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset1.concatenate(dataset2)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[2 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
                     [[i, i + 1] for i in range(0, 32, 2)])
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testConcatenateDifferentShapes(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      16, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset1.concatenate(dataset2)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                     [[i, i + 1] for i in range(0, 32, 2)])
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value. To achieve this, we first unbatch the input dataset
  and then rebatch it with the per-replica batch size that is calculated
  using `global_batch_size // split_batch_by`. The currently supported
  datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  TODO(priyag): Support multi worker / host cases properly by cloning and
  sharding the dataset on each worker. The current setup will only work in
  some cases, such as the in-graph multi-worker GPU case. If the input
  pipeline has random shuffling (with a different seed on each worker), each
  worker will see random input from the same overall dataset in each step.
  Otherwise, each worker will see the same input in each step.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in the future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      cloned_dataset = dataset
      if not context.executing_eagerly():
        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
        cloned_dataset = cloned_dataset.with_options(dataset.options())
      iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                              worker_devices)
      iterators.append(iterator)

  self._element_structure = dataset._element_structure  # pylint: disable=protected-access

  super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value. To achieve this, we first unbatch the input dataset
  and then rebatch it with the per-replica batch size that is calculated
  using `global_batch_size // split_batch_by`. The currently supported
  datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  We clone and shard the dataset on each worker. The current setup tries to
  shard the dataset by files if possible so that each worker sees a
  different subset of files. If that is not possible, we attempt to shard
  the final input such that each worker runs the entire preprocessing
  pipeline and only receives its own shard of the dataset.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in the future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      cloned_dataset = dataset
      if not context.executing_eagerly():
        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
        cloned_dataset = cloned_dataset.with_options(dataset.options())
      # TODO(b/129506833): Figure out between-graph cases.
      cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          cloned_dataset, len(input_workers.worker_devices), i)
      iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                              worker_devices)
      iterators.append(iterator)

  self._element_structure = dataset._element_structure  # pylint: disable=protected-access

  super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
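# Conceptual model of the auto-sharding fallback used above (illustrative
# only; `_shard_model` is hypothetical): when sharding by files is not
# possible, worker `index` keeps every `num_workers`-th element, matching
# `tf.data.Dataset.shard(num_workers, index)` semantics.
def _shard_model(elements, num_workers, index):
  return [x for k, x in enumerate(elements) if k % num_workers == index]

# _shard_model(range(8), num_workers=4, index=1) -> [1, 5]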
def testZipDifferentShapes(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      16, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual(
      [[16], [8]] if drop_remainder else [[None], [None]],
      [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
  self.assertEqual(
      [[4], [2]] if drop_remainder else [[None], [None]],
      [ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)])
  expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
                     for i in range(0, 32, 2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testFlatMapBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(2).flat_map(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder))
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Two elements, each of which is range(32).
  expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Eight batches of 8: range(32) split into four consecutive batches of 8,
  # with the whole sequence repeated twice.
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for _ in range(2)
                     for i in range(0, 32, 8)]  # generates 4 batches per pass
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testInterleaveBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(2).interleave(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder),
      cycle_length=2)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Two elements, each of which is range(32).
  expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Eight batches of 8 covering 0..31, with each batch repeated twice in
  # collated fashion, i.e. [0...7], [0...7], [8...15], [8...15], etc.
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 32, 8)  # generates 4 distinct batches
                     for _ in range(2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testMultipleBatches(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(128).batch(
      4, drop_remainder=drop_remainder)
  dataset = dataset.batch(8, drop_remainder=drop_remainder)
  self.assertEqual([[8, 4]] if drop_remainder else [[None, None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Each element is a list of 8 elements where each element is a list of 4.
  expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 32, 4)]  # generates 8 elements
                     for i in range(0, 128, 32)]
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = batching._RebatchDataset(dataset, 4)
  self.assertEqual([[2, 4]] if drop_remainder else [[None, None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Each element is a list of 2 elements where each element is a list of 4.
  expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 8, 4)]  # generates 2 elements
                     for i in range(0, 128, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testPaddedBatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
      8, padded_shapes=[5], drop_remainder=drop_remainder)
  rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8, 5]] if drop_remainder else [[None, 5]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Each element is a list of 8 elements in which each element is a list of 5
  # elements, first four are numbers and the last one is a padded zero.
  expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 32, 4)]  # generates 8 elements
                     for i in range(0, 128, 32)]
  self.assertDatasetProduces(dataset, expected_output)

  self.assertEqual([[2, 5]] if drop_remainder else [[None, 5]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Each element is a list of 2 elements in which each element is a list of 5
  # elements, first four are numbers and the last one is a padded zero.
  expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 8, 4)]  # generates 2 elements
                     for i in range(0, 128, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
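# Sketch of the padding behavior the test above relies on (illustrative;
# `_pad_to` is hypothetical): each inner batch of 4 values is padded with
# zeros up to padded_shapes=[5], hence the trailing 0 in every row.
def _pad_to(values, target_len, pad_value=0):
  return list(values) + [pad_value] * (target_len - len(values))

# _pad_to([0, 1, 2, 3], 5) -> [0, 1, 2, 3, 0]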
def build_dataset(num_elements, batch_size):
  return batching._RebatchDataset(
      dataset_ops.Dataset.range(num_elements).batch(
          4 * batch_size, drop_remainder=True),
      num_workers=4)
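# Hypothetical usage of the helper above: global batches of 4 * 8 == 32,
# rebatched across 4 workers, yield per-worker batches of 8.
dataset = build_dataset(num_elements=1024, batch_size=8)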
def testUnknownBatchSizeError(self):
  dataset = dataset_ops.Dataset.range(1024).batch(32)
  with self.assertRaisesRegexp(ValueError, "unknown batch size datasets"):
    batching._RebatchDataset(dataset, num_workers=4)
def testNotDivisibleError(self, drop_remainder):
  # TODO(rohanj): This should fail even with drop_remainder=False, by adding
  # an assertion in the mutated graph.
  dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
  with self.assertRaisesRegexp(ValueError, "not divisible by"):
    batching._RebatchDataset(dataset, num_workers=5)
def testScalarInputError(self, _):
  dataset = dataset_ops.Dataset.range(1024)
  with self.assertRaisesRegexp(ValueError, "at least one dimension"):
    batching._RebatchDataset(dataset, num_workers=4)