Ejemplo n.º 1
0
 def testNestedDictionaryOutput(self):
     """Rebatching splits every leaf of a nested dict element structure."""

     def _make_element(x):
         return {"a": x, "b": {"c": x + 1}}

     dataset = dataset_ops.Dataset.range(8).map(_make_element).batch(4)
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=2)
     # Each batch of 4 is split into two sub-batches of 2; the nested dict
     # structure is preserved, only the batch dimension shrinks.
     expected_output = [
         {"a": [2 * i, 2 * i + 1], "b": {"c": [2 * i + 1, 2 * i + 2]}}
         for i in range(4)
     ]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
Ejemplo n.º 2
0
 def testNoneDataset(self):
     """Rebatching tolerates components that have no output shape.

     Some datasets, e.g. those carrying None tensors, have components
     without output shapes; constructing the rebatched dataset should not
     break shape-inference logic.
     """
     dataset = (
         dataset_ops.Dataset.range(4)
         .map(lambda x: (x, None))
         .batch(4, drop_remainder=True))
     # Construction alone exercises the shape-inference code path.
     _ = distribute._LegacyRebatchDataset(dataset, num_replicas=2)
 def testTupleOutput(self):
     """Rebatching applies to every component of a tuple element."""
     dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=4)
     # Batches of 32 are split into 4 sub-batches of 8; both tuple
     # components carry identical values.
     expected_output = []
     for start in range(0, 1024, 8):
         chunk = list(range(start, start + 8))
         expected_output.append((chunk, chunk))
     self.assertDatasetProduces(rebatched_dataset, expected_output)
Ejemplo n.º 4
0
 def testCanHandleUnknownDims(self):
     """Shapes with unknown dimensions survive rebatching unchanged."""
     dataset = dataset_ops.Dataset.range(1000)
     # drop_remainder=False leaves both batch dimensions unknown (None).
     for _ in range(2):
         dataset = dataset.batch(10, drop_remainder=False)
     self.assertEqual([[None, None]], _flat_shapes(dataset))
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=4)
     # Only the dataset shapes are checked here, not the actual output.
     self.assertEqual([[None, None]], _flat_shapes(rebatched_dataset))
Ejemplo n.º 5
0
 def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
     """Trailing replicas get empty batches when batch % replicas != 0."""
     dataset = dataset_ops.Dataset.range(8).batch(
         4, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=3)
     self.assertEqual([[None]], _flat_shapes(rebatched_dataset))
     # Sub-batch size is ceil(4 / 3) = 2, so the first two replicas consume
     # each whole batch and the third replica receives an empty batch.
     expected_output = [[0, 1], [2, 3], [], [4, 5], [6, 7], []]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
Ejemplo n.º 6
0
    def testBasic(self, drop_remainder):
        """Rebatching across 2 replicas halves the per-step batch size."""
        dataset = dataset_ops.Dataset.range(8).batch(
            4, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._LegacyRebatchDataset(
            dataset, num_replicas=2)

        # With drop_remainder the sub-batch size is statically 2; without it
        # the batch dimension stays unknown.
        self.assertEqual([[2]] if drop_remainder else [[None]],
                         _flat_shapes(rebatched_dataset))

        self.assertDatasetProduces(rebatched_dataset,
                                   [[0, 1], [2, 3], [4, 5], [6, 7]])
Ejemplo n.º 7
0
 def testFileShardingWithLegacyRebatch(self):
     """File-level auto-sharding treats RebatchDatasetV1 as a passthrough."""
     self._setUpFiles(num_files=5, num_records_per_file=10)
     dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
     # The Shard op must be hoisted above the reader even though a Rebatch
     # op sits at the end of the pipeline.
     dataset = dataset.apply(
         testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
     dataset = distribute._LegacyRebatchDataset(
         dataset.batch(5), num_replicas=5)
     dataset = distribute._AutoShardDataset(dataset, 5, 3)
     # Only shard index 3's file is read; rebatching yields batches of one.
     expected = [[self._record(3, record_idx)] for record_idx in range(10)]
     self.assertDatasetProduces(dataset, expected)
 def testShardWithLegacyRebatch(self):
     """RebatchDatasetV1 is a passthrough op for auto-sharding."""
     dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                              shuffle=False)
     # Sharding should be applied to the file list, before the reader.
     dataset = dataset.apply(
         testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
     dataset = dataset.batch(5)
     dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=1)
     dataset = distribute._AutoShardDataset(dataset, 5, 3)
     # Pulling a single element is enough to verify the pipeline runs.
     get_next = self.getNext(dataset)
     self.evaluate(get_next())
Ejemplo n.º 9
0
 def testFinalPartialBatchAfterRebatch(self, drop_remainder):
     """A trailing partial batch is split (or dropped) after rebatching."""
     dataset = dataset_ops.Dataset.range(9).batch(
         4, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=2)
     self.assertEqual([[2] if drop_remainder else [None]],
                      _flat_shapes(rebatched_dataset))
     full_batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
     if drop_remainder:
         expected_output = full_batches
     else:
         # The final one-element batch rebatches into [8] plus an empty
         # sub-batch for the second replica.
         expected_output = full_batches + [[8], []]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testCanHandleUnknownRank(self):
     """Completely unknown-rank shapes survive rebatching."""
     dataset = dataset_ops.Dataset.from_tensors("xxx")
     # decode_image produces a tensor of completely unknown shape, i.e.
     # unknown rank.
     dataset = dataset.map(image_ops.decode_image)
     unknown_shape = [tensor_shape.TensorShape(None)]
     self.assertEqual(
         unknown_shape,
         nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)))
     rebatched_dataset = distribute._LegacyRebatchDataset(
         dataset, num_replicas=4)
     # Only the dataset shapes are tested here, not the actual output.
     self.assertEqual(
         unknown_shape,
         nest.flatten(
             dataset_ops.get_legacy_output_shapes(rebatched_dataset)))
Ejemplo n.º 11
0
    def testMultipleBatches(self):
        """Only the outermost batch dimension is rebatched."""
        dataset = dataset_ops.Dataset.range(16).batch(2).batch(4)
        self.assertEqual([[None, None]], _flat_shapes(dataset))

        # The inner batches: [0, 1], [2, 3], ..., [14, 15].
        pairs = [[2 * i, 2 * i + 1] for i in range(8)]

        # Before rebatching: two elements of four inner batches each.
        self.assertDatasetProduces(dataset, [pairs[:4], pairs[4:]])

        rebatched_dataset = distribute._LegacyRebatchDataset(dataset, 2)
        self.assertEqual([[None, None]], _flat_shapes(rebatched_dataset))
        # After rebatching: four elements of two inner batches each; the
        # inner batch size of 2 is untouched.
        self.assertDatasetProduces(
            rebatched_dataset, [pairs[i:i + 2] for i in range(0, 8, 2)])
Ejemplo n.º 12
0
    def testRaggedTensorDataset(self):
        """Rebatching splits ragged tensors that have a static batch size."""
        # Build a dataset of ragged rows with random lengths and values.
        row_lengths = np.random.randint(8, size=128)
        values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
        dataset = dataset_ops.Dataset.from_tensor_slices(
            ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
        dataset = dataset.batch(32, drop_remainder=True)

        # The identity map changes the internal representation of the ragged
        # tensor; the test fails if the representation is not normalized.
        dataset = dataset.map(lambda x: x)

        dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=8)
        # After rebatching, the batch size is 32 / 8 = 4. Reconstruct each
        # expected sub-batch by walking the flat values with a running offset.
        expected_output = []
        offset = 0
        for sub_batch_lengths in row_lengths.reshape((-1, 4)):
            count = np.sum(sub_batch_lengths)
            expected_output.append(
                ragged_tensor.RaggedTensor.from_row_lengths(
                    values[offset:(offset + count)], sub_batch_lengths))
            offset += count
        self.assertDatasetProduces(dataset, expected_output)
Ejemplo n.º 13
0
 def build_dataset(num_elements, batch_size):
     """Build a range dataset batched at 4 * batch_size, rebatched over 4 replicas."""
     base = dataset_ops.Dataset.range(num_elements)
     batched = base.batch(4 * batch_size, drop_remainder=True)
     return distribute._LegacyRebatchDataset(batched, num_replicas=4)
Ejemplo n.º 14
0
 def testScalarInputError(self):
     """Rebatching rejects unbatched (scalar-element) input.

     A batched dataset is accepted; an unbatched one raises ValueError.
     """
     dataset = dataset_ops.Dataset.range(1024)
     # Batched input constructs without error.
     distribute._LegacyRebatchDataset(dataset.batch(4), num_replicas=4)
     # Pin the actionable part of the error message as well (consistent
     # with the sibling testScalarInputError that uses assertRaisesRegex),
     # so a regression to an unhelpful message is caught.
     with self.assertRaisesRegex(ValueError, ("You can fix the issue "
                                              "by adding the `batch`")):
         distribute._LegacyRebatchDataset(dataset, num_replicas=4)
Ejemplo n.º 15
0
 def testScalarInputError(self):
     """Unbatched input raises a ValueError that tells the user the fix."""
     dataset = dataset_ops.Dataset.range(1024)
     # Batched input constructs without error.
     distribute._LegacyRebatchDataset(dataset.batch(4), num_replicas=4)
     # Unbatched input fails with guidance on how to fix it.
     expected_message = "You can fix the issue by adding the `batch`"
     with self.assertRaisesRegex(ValueError, expected_message):
         distribute._LegacyRebatchDataset(dataset, num_replicas=4)