def testComputeBatchSizeKnownAndMismatched(self):
    """Zipped components with differing static batch sizes yield -1.

    Both branches use drop_remainder=True, so each has a known batch size,
    but they disagree (4 vs. 8); no single batch size exists for the zip.
    """
    source = dataset_ops.Dataset.range(32)
    small_batches = source.batch(4, drop_remainder=True)
    large_batches = source.batch(8, drop_remainder=True)
    zipped = dataset_ops.Dataset.zip((small_batches, large_batches))
    self.assertEqual(-1, self.evaluate(distribute.compute_batch_size(zipped)))
Esempio n. 2
0
 def testNoneDataset(self):
     """Components without output shapes must not break batch-size logic.

     Mapping each element to ``(x, None)`` introduces a component with no
     output shape; after batching four elements with drop_remainder=True the
     computed batch size is still expected to be 4.
     """
     batched = (
         dataset_ops.Dataset.range(4)
         .map(lambda value: (value, None))
         .batch(4, drop_remainder=True)
     )
     computed = distribute.compute_batch_size(batched)
     self.assertEqual(4, self.evaluate(computed))
Esempio n. 3
0
 def testComputeBatchSizeWithZipMismatched(self):
     """Zipping batch(4) with batch(8), without drop_remainder, yields -1."""
     numbers = dataset_ops.Dataset.range(32)
     branches = (numbers.batch(4), numbers.batch(8))
     zipped = dataset_ops.Dataset.zip(branches)
     self.assertEqual(-1, self.evaluate(distribute.compute_batch_size(zipped)))
Esempio n. 4
0
 def testComputeBatchSizeWithPassthroughInvalid(self):
     """A `map` applied after batching is expected to make the result -1."""
     # NOTE(review): the test name suggests `map` is not treated as a
     # batch-size-preserving passthrough op; confirm against compute_batch_size.
     mapped = dataset_ops.Dataset.range(32).batch(4).map(lambda elem: elem + 1)
     expected = -1
     inferred = distribute.compute_batch_size(mapped)
     self.assertEqual(expected, self.evaluate(inferred))
Esempio n. 5
0
 def testComputeBatchSizeWithPassthrough(self):
     """A `take` after batching passes the batch size (4) through unchanged."""
     truncated = dataset_ops.Dataset.range(32).batch(4).take(5)
     inferred = distribute.compute_batch_size(truncated)
     self.assertEqual(4, self.evaluate(inferred))
Esempio n. 6
0
 def testComputeBatchSizeUnknown(self):
     """Batch size is still computed when the type spec does not fix it.

     The dataset is batched without drop_remainder=True, so the batch
     dimension cannot be inferred from the type spec alone; the test expects
     compute_batch_size to report 4 nonetheless.
     """
     self.assertEqual(
         4,
         self.evaluate(
             distribute.compute_batch_size(
                 dataset_ops.Dataset.range(32).batch(4))))
Esempio n. 7
0
 def testComputeBatchSizeKnown(self):
     """With drop_remainder=True the batch size comes from the type spec.

     Zipping the dataset with itself keeps a consistent batch size of 4.
     """
     batched = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
     paired = dataset_ops.Dataset.zip((batched, batched))
     computed = distribute.compute_batch_size(paired)
     self.assertEqual(4, self.evaluate(computed))