def testNestedDictionaryOutput(self):
  # Rebatching should recurse into nested structures such as dicts of dicts.
  dataset = dataset_ops.Dataset.range(8).map(
      lambda x: {"a": x, "b": {"c": x + 1}}).batch(4)
  rebatched_dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=2)
  # Each original batch of 4 splits into two sub-batches of 2.
  expected_output = [
      {"a": [i, i + 1], "b": {"c": [i + 1, i + 2]}} for i in range(0, 8, 2)
  ]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testNoneDataset(self):
  # Some datasets, e.g. those with None tensors, have components without
  # output shapes. Rebatching shape inference must tolerate that.
  ds = dataset_ops.Dataset.range(4).map(lambda x: (x, None)).batch(
      4, drop_remainder=True)
  _ = distribute._LegacyRebatchDataset(ds, num_replicas=2)
def testTupleOutput(self):
  # Tuples of tensors should be rebatched componentwise.
  dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
  rebatched_dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=4)
  # 4 replicas turn batches of 32 into sub-batches of 8.
  expected_output = [
      (list(range(start, start + 8)), list(range(start, start + 8)))
      for start in range(0, 1024, 8)
  ]
  self.assertDatasetProduces(rebatched_dataset, expected_output)
def testCanHandleUnknownDims(self):
  # Two non-dropping batches leave both batch dimensions unknown.
  dataset = dataset_ops.Dataset.range(1000).batch(
      10, drop_remainder=False).batch(10, drop_remainder=False)
  self.assertEqual([[None, None]], _flat_shapes(dataset))
  rebatched = distribute._LegacyRebatchDataset(dataset, num_replicas=4)
  # Only the dataset shapes are checked here, not the actual output.
  self.assertEqual([[None, None]], _flat_shapes(rebatched))
def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=drop_remainder)
  rebatched = distribute._LegacyRebatchDataset(dataset, num_replicas=3)
  self.assertEqual([[None]], _flat_shapes(rebatched))
  # Sub-batch size is ceil(4 / 3) = 2, so per original batch only the first
  # two replicas receive data and the third gets an empty batch.
  expected_output = [[0, 1], [2, 3], [], [4, 5], [6, 7], []]
  self.assertDatasetProduces(rebatched, expected_output)
def testBasic(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=drop_remainder)
  rebatched = distribute._LegacyRebatchDataset(dataset, num_replicas=2)
  # With drop_remainder the static sub-batch size (2) is known; otherwise
  # the batch dimension stays unknown.
  self.assertEqual([[2]] if drop_remainder else [[None]],
                   _flat_shapes(rebatched))
  self.assertDatasetProduces(rebatched, [[0, 1], [2, 3], [4, 5], [6, 7]])
def testFileShardingWithLegacyRebatch(self):
  # RebatchDatasetV1 should be a passthrough op for file-based auto-sharding.
  self._setUpFiles(num_files=5, num_records_per_file=10)
  dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  dataset = dataset.apply(
      testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
  dataset = dataset.flat_map(core_readers.TFRecordDataset)
  dataset = dataset.batch(5)
  dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
  dataset = distribute._AutoShardDataset(dataset, 5, 3)
  # Shard index 3 of 5 should see only file 3's records, one per sub-batch.
  self.assertDatasetProduces(
      dataset, [[self._record(3, i)] for i in range(10)])
def testShardWithLegacyRebatch(self):
  # RebatchDatasetV1 should be a passthrough op: sharding still hoists above
  # the FlatMap/Batch/Rebatch chain.
  dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
  dataset = dataset.apply(
      testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
  dataset = dataset.flat_map(core_readers.TFRecordDataset)
  dataset = dataset.batch(5)
  dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=1)
  dataset = distribute._AutoShardDataset(dataset, 5, 3)
  # Only pull one element; the assert_next transformation does the checking.
  get_next = self.getNext(dataset)
  self.evaluate(get_next())
def testFinalPartialBatchAfterRebatch(self, drop_remainder):
  dataset = dataset_ops.Dataset.range(9).batch(4, drop_remainder=drop_remainder)
  rebatched = distribute._LegacyRebatchDataset(dataset, num_replicas=2)
  self.assertEqual([[2] if drop_remainder else [None]],
                   _flat_shapes(rebatched))
  expected_output = [[0, 1], [2, 3], [4, 5], [6, 7]]
  if not drop_remainder:
    # The leftover element 8 forms a partial sub-batch for the first replica
    # and an empty one for the second.
    expected_output += [[8], []]
  self.assertDatasetProduces(rebatched, expected_output)
def testCanHandleUnknownRank(self):
  dataset = dataset_ops.Dataset.from_tensors("xxx")
  # decode_image results in a tensor of completely unknown shape (i.e.
  # unknown rank).
  dataset = dataset.map(image_ops.decode_image)
  unknown_shape = [tensor_shape.TensorShape(None)]
  self.assertEqual(
      unknown_shape,
      nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)))
  rebatched = distribute._LegacyRebatchDataset(dataset, num_replicas=4)
  # Only the dataset shapes are checked, not the actual output; the rank
  # stays unknown after rebatching.
  self.assertEqual(
      unknown_shape,
      nest.flatten(dataset_ops.get_legacy_output_shapes(rebatched)))
def testMultipleBatches(self):
  dataset = dataset_ops.Dataset.range(16).batch(2).batch(4)
  self.assertEqual([[None, None]], _flat_shapes(dataset))
  # Before rebatching: two outer batches, each holding four pairs.
  self.assertDatasetProduces(
      dataset,
      [[[0, 1], [2, 3], [4, 5], [6, 7]],
       [[8, 9], [10, 11], [12, 13], [14, 15]]])
  rebatched = distribute._LegacyRebatchDataset(dataset, 2)
  self.assertEqual([[None, None]], _flat_shapes(rebatched))
  # After rebatching: four outer batches, each holding two pairs.
  self.assertDatasetProduces(
      rebatched,
      [[[2 * i, 2 * i + 1], [2 * i + 2, 2 * i + 3]] for i in range(0, 8, 2)])
def testRaggedTensorDataset(self):
  # Set up a dataset that produces ragged tensors with a static batch size.
  row_lengths = np.random.randint(8, size=128)
  values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices(
      ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
  dataset = dataset.batch(32, drop_remainder=True)
  # The identity map changes the internal representation of the ragged
  # tensor; the test fails if that representation is not normalized.
  dataset = dataset.map(lambda x: x)
  dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=8)
  # After rebatching across 8 replicas, the batch size is 4. Walk the flat
  # values in step with the per-batch row lengths to build the expectation.
  expected_output = []
  offset = 0
  for lengths in row_lengths.reshape((-1, 4)):
    count = np.sum(lengths)
    expected_output.append(
        ragged_tensor.RaggedTensor.from_row_lengths(
            values[offset:offset + count], lengths))
    offset += count
  self.assertDatasetProduces(dataset, expected_output)
def build_dataset(num_elements, batch_size):
  # Batch at 4x the target size, then rebatch across 4 replicas so each
  # replica effectively sees batches of `batch_size`.
  base = dataset_ops.Dataset.range(num_elements).batch(
      4 * batch_size, drop_remainder=True)
  return distribute._LegacyRebatchDataset(base, num_replicas=4)
def testScalarInputError(self):
  dataset = dataset_ops.Dataset.range(1024)
  # A batched dataset is accepted without error...
  distribute._LegacyRebatchDataset(dataset.batch(4), num_replicas=4)
  # ...but an unbatched (scalar-element) dataset is rejected.
  with self.assertRaises(ValueError):
    distribute._LegacyRebatchDataset(dataset, num_replicas=4)
def testScalarInputError(self):
  dataset = dataset_ops.Dataset.range(1024)
  # A batched dataset is accepted without error...
  distribute._LegacyRebatchDataset(dataset.batch(4), num_replicas=4)
  # ...but an unbatched (scalar-element) dataset raises with a message that
  # points the user at the fix.
  with self.assertRaisesRegex(
      ValueError, "You can fix the issue by adding the `batch`"):
    distribute._LegacyRebatchDataset(dataset, num_replicas=4)