def test_batch_data_padding(self):
    """Variable-length samples in one batch are padded to a common length."""
    # Sample k holds (10 - k) ones, paired with the label k + 1.
    samples = (([1] * (10 - k), k + 1) for k in range(10))
    first_batch = next(data.batch(samples, 10))
    features = first_batch[0]
    self.assertEqual(features.shape, (10, 10))
    # The last (shortest) sample is a single 1; the rest of its row must be zeros.
    expected_last_row = np.asarray([1] + [0] * 9)
    self.assertTrue(np.array_equal(features[-1], expected_last_row))
def test_batch_exception_size(self):
    """A batch size of zero is rejected with ValueError."""
    samples = ((k, k + 1) for k in range(10))
    # The call and the first next() both sit inside the context manager, so the
    # test passes whether batch() validates eagerly or on first iteration.
    with self.assertRaises(ValueError):
        next(data.batch(samples, 0))
def test_batch_data(self):
    """Ten scalar (feature, label) pairs produce a 2-field batch of length 10."""
    samples = ((k, k + 1) for k in range(10))
    batch_iter = data.batch(samples, 10)
    first_batch = next(batch_iter)
    self.assertLen(first_batch, 2)
    features = first_batch[0]
    self.assertEqual(features.shape, (10, ))