Ejemplo n.º 1
0
 def test_generator_datashape(self, n, bs):
     """Check the shape of every batch produced by ``SVHNDataset.generator``.

     The dataset's ``_images`` and ``labels`` are replaced with random
     arrays of ``n`` samples, and the generator is built with
     ``flatten=False`` and ``ae=False``. Every batch except the last must
     hold ``bs`` samples; the last holds the remainder ``n % bs`` (or a
     full ``bs`` when ``n`` is an exact multiple of ``bs``).

     Args:
         n: Number of synthetic samples (presumably supplied by test
             parametrization -- not visible here).
         bs: Batch size passed to the generator.
     """
     ds = SVHNDataset("test")
     # randint's ``high`` is exclusive: 256 covers the full 0..255 pixel range.
     ds._images = np.random.randint(low=0, high=256, size=(n, 32, 32, 3))
     # Labels drawn from 1..9 (``high`` is exclusive).
     ds.labels = np.random.randint(low=1, high=10, size=(n, 1))
     ds_gen = ds.generator(batch_size=bs, flatten=False, ae=False)

     num_batches = len(ds_gen)
     # Expected size of the final batch: the remainder, or a full batch
     # when n divides evenly by bs.
     last_size = n % bs or bs
     for i in range(num_batches):
         # Fetch each batch once so images and labels come from the same call.
         batch = ds_gen[i]
         expected = last_size if i == num_batches - 1 else bs
         assert (expected, 32, 32, 3) == batch[0].shape
         assert (expected, 1) == batch[1].shape
Ejemplo n.º 2
0
 def test_generator_batch(self, n, bs):
     """Check that ``SVHNDataset.generator`` yields ceil(n / bs) batches.

     The dataset's ``_images`` and ``labels`` are replaced with random
     arrays of ``n`` samples before the generator is built.

     Args:
         n: Number of synthetic samples (presumably supplied by test
             parametrization -- not visible here).
         bs: Batch size passed to the generator.
     """
     ds = SVHNDataset("test")
     # randint's ``high`` is exclusive: 256 covers the full 0..255 pixel range.
     ds._images = np.random.randint(low=0, high=256, size=(n, 32, 32, 3))
     # Labels drawn from 1..9 (``high`` is exclusive).
     ds.labels = np.random.randint(low=1, high=10, size=(n, 1))
     # Exact integer ceiling division; avoids float rounding from np.ceil(n / bs).
     expected_batches = -(-n // bs)
     assert expected_batches == len(ds.generator(batch_size=bs))