def testComputeBatchSizeKnownAndMismatched(self):
  # Zipping components batched with different, statically-known batch sizes
  # (drop_remainder=True makes the size part of the type spec) must yield -1.
  base = dataset_ops.Dataset.range(32)
  zipped = dataset_ops.Dataset.zip(
      (base.batch(4, drop_remainder=True), base.batch(8, drop_remainder=True)))
  inferred = distribute.compute_batch_size(zipped)
  self.assertEqual(-1, self.evaluate(inferred))
def testNoneDataset(self):
  # Components produced from None tensors carry no output shape; batch-size
  # inference must tolerate them instead of crashing.
  ds = dataset_ops.Dataset.range(4)
  ds = ds.map(lambda x: (x, None))
  ds = ds.batch(4, drop_remainder=True)
  inferred = distribute.compute_batch_size(ds)
  self.assertEqual(4, self.evaluate(inferred))
def testComputeBatchSizeWithZipMismatched(self):
  # Zipped components batched with conflicting sizes (no drop_remainder, so
  # the sizes come from graph analysis) cannot agree on one batch size: -1.
  base = dataset_ops.Dataset.range(32)
  zipped = dataset_ops.Dataset.zip((base.batch(4), base.batch(8)))
  inferred = distribute.compute_batch_size(zipped)
  self.assertEqual(-1, self.evaluate(inferred))
def testComputeBatchSizeWithPassthroughInvalid(self):
  # A map after batching may rearrange elements arbitrarily, so the batch
  # size can no longer be traced through the transformation: expect -1.
  ds = dataset_ops.Dataset.range(32).batch(4).map(lambda x: x + 1)
  inferred = distribute.compute_batch_size(ds)
  self.assertEqual(-1, self.evaluate(inferred))
def testComputeBatchSizeWithPassthrough(self):
  # `take` passes batches through unchanged, so the size inferred upstream
  # of the transformation still applies.
  ds = dataset_ops.Dataset.range(32).batch(4).take(5)
  inferred = distribute.compute_batch_size(ds)
  self.assertEqual(4, self.evaluate(inferred))
def testComputeBatchSizeUnknown(self):
  # Without drop_remainder the type spec leaves the batch dimension unknown,
  # but graph inspection still recovers the batch size of 4.
  ds = dataset_ops.Dataset.range(32).batch(4)
  inferred = distribute.compute_batch_size(ds)
  self.assertEqual(4, self.evaluate(inferred))
def testComputeBatchSizeKnown(self):
  # With drop_remainder=True the batch size is encoded in the type spec, and
  # zipping two identically-batched components preserves it.
  batched = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
  zipped = dataset_ops.Dataset.zip((batched, batched))
  inferred = distribute.compute_batch_size(zipped)
  self.assertEqual(4, self.evaluate(inferred))