def test_serial_composes(self):
    """Check that data.Serial works inside another data.Serial."""

    def dataset(_):
        # Ten (i, i + 1) pairs; the positional argument is ignored.
        return ((i, i + 1) for i in range(10))

    serial1 = data.Serial(dataset, data.Shuffle(3))
    batches = data.Serial(serial1, data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))
    self.assertEqual(batch[1].shape, (10,))
    # Shuffling may reorder examples but must not drop, duplicate or split
    # them: the first field is a permutation of range(10), and every second
    # field still equals its paired first field plus one.
    firsts = batch[0].tolist()
    self.assertCountEqual(firsts, list(range(10)))
    self.assertEqual(batch[1].tolist(), [x + 1 for x in firsts])
def test_serial(self):
    """Check that data.Serial chains a dataset, Shuffle and Batch."""

    def dataset(_):
        # Ten (i, i + 1) pairs; the positional argument is ignored.
        return ((i, i + 1) for i in range(10))

    batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))
    self.assertEqual(batch[1].shape, (10,))
    # Shuffling may reorder examples but must not drop, duplicate or split
    # them: the first field is a permutation of range(10), and every second
    # field still equals its paired first field plus one.
    firsts = batch[0].tolist()
    self.assertCountEqual(firsts, list(range(10)))
    self.assertEqual(batch[1].tolist(), [x + 1 for x in firsts])