def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None, **kwargs):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we will attempt to
  # shard the final input such that each worker will run the entire
  # preprocessing pipeline and only receive its own shard of the dataset.
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  self._cloned_datasets = []
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._strategy = strategy
  self._kwargs = kwargs

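# The following helper is a minimal illustrative sketch, not part of the
# original file: it shows the `split_batch_by` rebatching semantics described
# in the docstring above, using the same legacy `distribute._RebatchDataset`
# signature exercised by the tests below. The function name is hypothetical.
def _rebatch_split_example():
  """Sketch: each global batch of 16 is split into 4 sub-batches of 4."""
  dataset = dataset_ops.Dataset.range(64).batch(16, drop_remainder=True)
  # Yields [0..3], [4..7], [8..11], [12..15], [16..19], ... in order.
  return distribute._RebatchDataset(dataset, 4)
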
def testNotDivisibleError(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder)
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "not divisible by"):
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testUnsupportedTransformError(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder).apply(
          scan_ops.scan([0], lambda _, a: ([0], a)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testBatchSizesDontMatch(self):
  dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Cannot use rebatching fallback"):
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def __init__(self, dataset, input_workers, split_batch_by=None,
             input_context=None, **kwargs):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    **kwargs: Additional experimental flags. Will be removed in future.
  """
  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we will attempt to
  # shard the final input such that each worker will run the entire
  # preprocessing pipeline and only receive its own shard of the dataset.
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  self._cloned_datasets = []
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._kwargs = kwargs

def testNestedDictionaryOutput(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).map(
      lambda x: {"a": x, "b": {"c": x}}).batch(
          32, drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                      "b": {"c": [k for k in range(i, i + 8)]}}
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testPartialBatchWithDropRemainder(self):
  dataset = dataset_ops.Dataset.range(5).batch(4, drop_remainder=False)
  rebatched_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=[2, 2], drop_remainder=True)
  expected_shapes = [[2]]
  self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
  expected_output = [[0, 1], [2, 3]]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testUnsupportedTransformError(self):
  dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched_dataset = distribute._RebatchDataset(
        dataset, num_replicas=4, use_fallback=False)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testMapAndBatch(self):
  dataset = dataset_ops.Dataset.range(1024).apply(
      batching.map_and_batch(math_ops.square, 32))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testUnsupportedTransformInFlatMapError(self):
  dataset = dataset_ops.Dataset.range(2).flat_map(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32).apply(sleep.sleep(10)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched_dataset = distribute._RebatchDataset(
        dataset, num_replicas=4, use_fallback=False)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testTupleOutput(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                      [k for k in range(i, i + 8)])
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testZip(self):
  dataset1 = dataset_ops.Dataset.range(64).batch(8)
  dataset2 = dataset_ops.Dataset.range(32).batch(8)
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[None], [None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                         with_prefetch):
  # This test simulates a distributed environment with 3 workers, each with
  # 1 replica.
  dataset = dataset_ops.Dataset.range(8)
  dataset = dataset.batch(4)
  options = options_lib.Options()
  options.experimental_distribute.auto_shard_policy = sharding_policy
  dataset = dataset.with_options(options)
  # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
  # RebatchDataset(V1) for correctness reasons. This will modify the output
  # of the dataset.
  worker_a_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=[2, 1, 1])
  if with_prefetch:
    worker_a_dataset = worker_a_dataset.prefetch(1)
  worker_a_dataset = distribute._AutoShardDataset(
      worker_a_dataset, 3, 0, num_replicas=3)
  expected = [[0, 1], [4, 5]]
  self.assertDatasetProduces(worker_a_dataset, expected)

  worker_b_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=[1, 1, 2])
  if with_prefetch:
    worker_b_dataset = worker_b_dataset.prefetch(1)
  worker_b_dataset = distribute._AutoShardDataset(
      worker_b_dataset, 3, 1, num_replicas=3)
  expected = [[2, 3], [6, 7]]
  self.assertDatasetProduces(worker_b_dataset, expected)

  worker_c_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=[1, 2, 1])
  if with_prefetch:
    worker_c_dataset = worker_c_dataset.prefetch(1)
  worker_c_dataset = distribute._AutoShardDataset(
      worker_c_dataset, 3, 2, num_replicas=3)
  expected = [[], []]
  self.assertDatasetProduces(worker_c_dataset, expected)

def testWithUnknownBatchDim(self):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=False).apply(sleep.sleep(10))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Cannot use rebatching fallback"):
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testConcatenateDifferentShapes(self):
  dataset1 = dataset_ops.Dataset.range(64).batch(16)
  dataset2 = dataset_ops.Dataset.range(32).batch(8)
  dataset = dataset1.concatenate(dataset2)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                     [[i, i + 1] for i in range(0, 32, 2)])
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testScalarBatchSizeInput(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
  rebatched_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=2, drop_remainder=drop_remainder)
  expected_shapes = [[2]]
  self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
  expected_output = [[0, 1], [2, 3], [4, 5], [6, 7]]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testCanHandleUnknownRank(self):
  dataset = dataset_ops.Dataset.from_tensors("xxx")
  # decode_image results in a tensor of completely unknown shape (i.e.
  # unknown rank).
  dataset = dataset.map(image_ops.decode_image)
  self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  # Note that we are just testing the dataset shapes, not the actual output.
  self.assertEqual([tensor_shape.TensorShape(None)],
                   _flat_shapes(rebatched_dataset))

def testCanHandleUnknownDims(self):
  dataset = dataset_ops.Dataset.range(1000)
  dataset = dataset.batch(10, drop_remainder=False)
  dataset = dataset.batch(10, drop_remainder=False)
  self.assertEqual([[None, None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  # Note that we are just testing the dataset shapes, not the actual output.
  self.assertEqual([[None, None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

def testWithNoBatchDataset(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)])  # pylint: disable=g-complex-comprehension
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testWithUnhandledTransformation(self):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=True).apply(sleep.sleep(10))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testScanAfterBatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(40).batch(10).apply(
      scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
  dataset = distribute._RebatchDataset(dataset, num_workers=2)
  self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
  expected_output = [[i * 2 for i in range(j * 5, (j + 1) * 5)]  # pylint: disable=g-complex-comprehension
                     for j in range(8)]
  self.assertDatasetProduces(dataset, expected_output)

def testBatchSizeIndivisibleByNumWorkers(self):
  # This doesn't work; reshape requires the tensor shape to be exactly
  # divisible by the second dim.
  dataset = dataset_ops.Dataset.range(64).batch(
      32, drop_remainder=True).apply(sleep.sleep(10))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Cannot use rebatching fallback"):
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
    next_element = self.getNext(rebatched_dataset)
    self.evaluate(next_element())

def testEmptyLastSplits(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
  rebatched_dataset = distribute._RebatchDataset(
      dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)
  expected_shapes = [[None]]
  self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
  expected_output = [[0], [], [1], [], [2], [], [3], [],
                     [4], [], [5], [], [6], [], [7], []]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testFinalPartialBatchAfterRebatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(34).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[8] if drop_remainder else [None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
  if not drop_remainder:
    expected_output += [[32, 33]]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testMultipleBatches(self):
  dataset = dataset_ops.Dataset.range(16).batch(
      2, drop_remainder=True).batch(4, drop_remainder=True)
  self.assertEqual([[4, 2]], _flat_shapes(dataset))
  rebatched_dataset = distribute._RebatchDataset(dataset, [2, 2])
  self.assertEqual([[2, 2]], _flat_shapes(rebatched_dataset))
  # Each element is a list of 2 elements where each element is a list of 2.
  expected_output = [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
                     [[8, 9], [10, 11]], [[12, 13], [14, 15]]]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testBasic(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[8] if drop_remainder else [None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testNestedDictionaryOutput(self):
  dataset = dataset_ops.Dataset.range(8).map(
      lambda x: {"a": x, "b": {"c": x + 1}}).batch(4, drop_remainder=True)
  rebatched_dataset = distribute._RebatchDataset(dataset, [2, 2])
  self.assertEqual([[2], [2]], _flat_shapes(rebatched_dataset))
  expected_output = [{"a": [0, 1], "b": {"c": [1, 2]}},
                     {"a": [2, 3], "b": {"c": [3, 4]}},
                     {"a": [4, 5], "b": {"c": [5, 6]}},
                     {"a": [6, 7], "b": {"c": [7, 8]}}]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testWithUnknownBatchDimInSecondComponent(self):
  dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
  dataset1 = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=False).apply(sleep.sleep(10))
  dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
  expected_output = [(x, x) for x in expected_output]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testMapAndBatchWithCapturedInput(self):
  captured_t = variables.Variable(42)
  dataset = dataset_ops.Dataset.range(1024).apply(
      batching.map_and_batch(lambda x: captured_t, 32))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
  self.assertEqual([[None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.evaluate(variables.global_variables_initializer())
  self.assertDatasetProduces(
      rebatched_dataset, expected_output, requires_initialization=True)

def testShardWithRebatch(self):
  # Tests that RebatchDatasetV2 is a passthrough op.
  dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  dataset = dataset.apply(
      testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
  dataset = dataset.flat_map(core_readers.TFRecordDataset)
  dataset = dataset.batch(5)
  dataset = distribute._RebatchDataset(dataset, batch_sizes=5)
  dataset = distribute._AutoShardDataset(dataset, 5, 3)
  nxt = self.getNext(dataset)
  self.evaluate(nxt())

def testBatchSizeNotDivisibleByNumReplicas2(self):
  dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
  # This will rebatch into sub-batches of size 4, since ceil(16 / 5) = 4.
  # However, that means only the first 4 replicas will get data.
  expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
  expected_output.extend([[]])  # Last replica gets an empty batch.
  expected_output.extend(
      [[k for k in range(i, i + 4)] for i in range(16, 32, 4)])
  expected_output.extend([[]])  # Last replica gets an empty batch.
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testFinalPartialBatchOriginal(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1032).batch(
      32, drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1032, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testMapAndBatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(1024).apply(
      batching.map_and_batch(math_ops.square, 32,
                             drop_remainder=drop_remainder))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testZip(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      8, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8], [8]] if drop_remainder else [[None], [None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[2], [2]] if drop_remainder else [[None], [None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testConcatenateDifferentShapes(self, drop_remainder):
  dataset1 = dataset_ops.Dataset.range(64).batch(
      16, drop_remainder=drop_remainder)
  dataset2 = dataset_ops.Dataset.range(32).batch(
      8, drop_remainder=drop_remainder)
  dataset = dataset1.concatenate(dataset2)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                     [[i, i + 1] for i in range(0, 32, 2)])
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testMapAndBatchWithCapturedInput(self, drop_remainder):
  captured_t = variables.Variable(42)
  dataset = dataset_ops.Dataset.range(1024).apply(
      batching.map_and_batch(lambda x: captured_t, 32,
                             drop_remainder=drop_remainder))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 1024, 8)]
  self.evaluate(variables.global_variables_initializer())
  self.assertDatasetProduces(
      rebatched_dataset, expected_output, requires_initialization=True)

def testGroupByWindowBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
  reduce_fn = lambda bucket_id, ds: ds.batch(
      batch_size=10, drop_remainder=drop_remainder)
  dataset = dataset.apply(
      grouping.group_by_window(
          key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2)
  self.assertEqual([[5, 3] if drop_remainder else [None, 3]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # pylint: disable=g-complex-comprehension
  expected_output = [[[j + i * 4 + k * 20] * 3 for i in range(5)]
                     for j in range(4)
                     for k in range(2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testInterleaveBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(2).interleave(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder),
      cycle_length=2)
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Two elements where each element is range(32).
  expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # List of 4 elements where each element is a list of 8, numbering from 0
  # to 31, repeated twice.
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for i in range(0, 32, 8)  # generates 4 elements
                     for _ in range(2)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testFlatMapBatching(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(2).flat_map(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder))
  self.assertEqual([[32 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Two elements where each element is range(32).
  expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8 if drop_remainder else None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Two elements where each element is a list of 4 elements where each
  # element is a list of 8.
  expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                     for _ in range(2)
                     for i in range(0, 32, 8)]  # generates 4 elements
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testMultipleBatches(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(128).batch(
      4, drop_remainder=drop_remainder)
  dataset = dataset.batch(8, drop_remainder=drop_remainder)
  self.assertEqual([[8, 4]] if drop_remainder else [[None, None]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Each element is a list of 8 elements where each element is a list of 4.
  expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 32, 4)]  # generates 8 elements
                     for i in range(0, 128, 32)]
  self.assertDatasetProduces(dataset, expected_output)

  rebatched_dataset = distribute._RebatchDataset(dataset, 4)
  self.assertEqual([[2, 4]] if drop_remainder else [[None, None]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Each element is a list of 2 elements where each element is a list of 4.
  expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 8, 4)]  # generates 2 elements
                     for i in range(0, 128, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def testPaddedBatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
      8, padded_shapes=[5], drop_remainder=drop_remainder)
  rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
  self.assertEqual([[8, 5]] if drop_remainder else [[None, 5]],
                   [ts.as_list() for ts in _flat_shapes(dataset)])
  # Each element is a list of 8 elements in which each element is a list of
  # 5 elements: the first four are numbers and the last one is a padded zero.
  expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 32, 4)]  # generates 8 elements
                     for i in range(0, 128, 32)]
  self.assertDatasetProduces(dataset, expected_output)

  self.assertEqual([[2, 5]] if drop_remainder else [[None, 5]],
                   [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
  # Each element is a list of 2 elements in which each element is a list of
  # 5 elements: the first four are numbers and the last one is a padded zero.
  expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                      for j in range(i, i + 8, 4)]  # generates 2 elements
                     for i in range(0, 128, 8)]
  self.assertDatasetProduces(rebatched_dataset, expected_output)

def build_dataset(num_elements, batch_size):
  return distribute._RebatchDataset(
      dataset_ops.Dataset.range(num_elements).batch(
          4 * batch_size, drop_remainder=True),
      num_workers=4)

def testScalarInputError(self, _):
  dataset = dataset_ops.Dataset.range(1024)
  with self.assertRaisesRegexp(ValueError, "at least one dimension"):
    distribute._RebatchDataset(dataset, num_workers=4)