def testBatchIndices(self):
    """Checks the count and size of the batches produced by batch_indices.

    Builds a 50-example dataset and requests 100 batches of 20 indices.
    It then asserts that the first two batches never hold the same index at
    the same position. This is an element-wise comparison, not a whole-batch
    comparison.
    """
    num_examples = 50
    dataset = datasets.Dataset(
        data=xrange(num_examples),
        labels=np.random.randint(2, size=num_examples))
    self.assertEqual(num_examples, dataset.size)

    num_batches, batch_size = 100, 20
    batches = dataset.batch_indices(num_batches, batch_size)
    self.assertLen(batches, num_batches)
    self.assertLen(batches[0], batch_size)

    # Walk the first two batches in lockstep; each pair is one index per batch.
    first, second = batches[0], batches[1]
    for idx_a, idx_b in zip(first, second):
        self.assertNotEqual(idx_a, idx_b)
def testBatchIndices_BadData(self):
    """Checks that batch_indices raises ValueError for an inconsistent dataset.

    The dataset has 10 data items but 50 labels. Construction happens outside
    the assertRaises context, so the test expects the constructor to accept
    these inputs and the error to come from batch_indices(5, 10).
    NOTE(review): presumably the data/label length mismatch is the trigger;
    confirm against Dataset.batch_indices.
    """
    mismatched = datasets.Dataset(
        data=xrange(10), labels=np.random.randint(2, size=50))
    with self.assertRaises(ValueError):
        mismatched.batch_indices(5, 10)