示例#1
0
 def test_batch_data_padding(self):
     """Check the padded shape of a batch built from variable-length inputs.

     Ten samples are generated; sample ``i`` pairs a list of ``10 - i`` ones
     with the value ``i + 1``. The first batch of size 10 is expected to
     hold its inputs as a ``(10, 10)`` array whose last row (built from the
     single-element sample) is one ``1`` followed by nine ``0``s, i.e. the
     test expects shorter inputs to be filled with trailing zeros.
     """
     # Input lengths shrink from 10 down to 1 across the ten samples.
     samples = (([1] * (10 - idx), idx + 1) for idx in range(10))
     first_batch = next(data.batch(samples, 10))
     inputs = first_batch[0]

     self.assertEqual(inputs.shape, (10, 10))

     # The shortest sample is a single 1; the remaining slots must be 0.
     expected_last_row = np.asarray([1] + 9 * [0])
     self.assertTrue(np.array_equal(inputs[-1], expected_last_row))
示例#2
0
 def test_batch_exception_size(self):
     """Check that requesting a batch size of 0 raises ``ValueError``."""
     pairs = ((idx, idx + 1) for idx in range(10))
     with self.assertRaises(ValueError):
         # Both the call and the first pull stay inside the context, so
         # the error is caught whether it is raised eagerly by
         # ``data.batch`` or only once the first batch is requested.
         next(data.batch(pairs, 0))
示例#3
0
 def test_batch_data(self):
     """Check the structure of a batch built from scalar ``(i, i + 1)`` pairs.

     The first batch of size 10 is expected to have two components, the
     first of which is a one-dimensional array of length 10.
     """
     pairs = ((idx, idx + 1) for idx in range(10))
     first_batch = next(data.batch(pairs, 10))

     # One component per element of the (input, target) pairs.
     self.assertLen(first_batch, 2)

     expected_shape = (10,)
     self.assertEqual(first_batch[0].shape, expected_shape)