Ejemplo n.º 1
0
  def testBatchIndices(self):
    """Checks the batches of indices returned by `Dataset.batch_indices`.

    Builds a 50-example dataset, calls `batch_indices(100, 20)` and expects
    100 batches back, the first holding 20 indices. It also asserts that
    the first two batches never hold the same index at the same position.
    """
    # `range` rather than `xrange`: identical on Python 3 and does not rely
    # on a `six.moves` import. A locally seeded RandomState keeps the labels
    # reproducible without touching the global NumPy RNG; the label values
    # themselves are never asserted on.
    labels = np.random.RandomState(0).randint(2, size=50)
    ds = datasets.Dataset(data=range(50), labels=labels)
    self.assertEqual(50, ds.size)

    batch_indices = ds.batch_indices(100, 20)
    self.assertLen(batch_indices, 100)
    self.assertLen(batch_indices[0], 20)
    # Pair up the index at each position of batch 0 with the index at the
    # same position of batch 1; no position may repeat the same index.
    for first_idx, second_idx in zip(batch_indices[0], batch_indices[1]):
      self.assertNotEqual(first_idx, second_idx)
Ejemplo n.º 2
0
 def testBatchIndices_BadData(self):
   """Expects `Dataset.batch_indices` to raise ValueError on bad data.

   The dataset is built with 10 data items but 50 labels. Construction sits
   outside the `assertRaises` context, so the test requires the error to
   surface only when batches are requested.
   """
   # Seeded locally for reproducibility; the label values are irrelevant,
   # only their count (50) matters here.
   labels = np.random.RandomState(0).randint(2, size=50)
   # NOTE(review): the 10-vs-50 data/label length mismatch is presumably the
   # "bad data" being rejected; confirm against `batch_indices` that it is
   # not the (5, 10) arguments over 10 items that trigger the error.
   ds = datasets.Dataset(data=range(10), labels=labels)
   with self.assertRaises(ValueError):
     ds.batch_indices(5, 10)