def test_basic_two_workers(self):
    """Stream 1000 samples through two workers and check all ten batches.

    The batches are not required to arrive in any particular order: each
    streamed batch only has to equal one of the ten expected splits.
    """
    # NOTE(review): a method with this exact name is defined twice in the
    # visible source; if both live in the same class the later definition
    # silently replaces the earlier one -- confirm and drop one copy.
    data = self.get_test_data(1000)
    batch_size = 100
    num_shards = 2
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=num_shards)
    # Expected batches, built with the reader's own split helper.
    splits = [reader._get_split(reader.data, i, batch_size) for i in range(10)]
    streamer = DataStreamer(reader, num_workers=num_shards)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty streamer would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for batch in streamer:
        num_batches += 1
        for split in splits:
            try:
                self.assert_batch_equal(split, batch, 0, batch_size)
            except Exception:
                # Deliberately broad: any failure just means "not this
                # split", so move on to the next candidate.
                continue
            break
        else:
            self.fail(
                "streamed batch #%d does not match any expected split"
                % (num_batches - 1)
            )
    self.assertEqual(10, num_batches)
def test_shard_drop_small(self):
    """Read 1000 samples as three shards and expect nine full batches.

    Each batch from shard ``s`` is compared against ``data`` starting at
    ``s * shard_size + i * batch_size``, where ``shard_size`` is
    ``n // num_shards + 1``.
    """
    total = 1000
    batch_size = 100
    num_shards = 3
    data = self.get_test_data(total)
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=num_shards)
    shard_size = total // num_shards + 1

    seen = 0
    for shard in range(num_shards):
        shard_start = shard * shard_size
        for idx, batch in enumerate(reader.get_shard(shard)):
            offset = shard_start + idx * batch_size
            self.assert_batch_equal(data, batch, offset, batch_size)
            seen += 1
    self.assertEqual(9, seen)
def test_drop_small(self):
    """Iterate a reader over 999 samples and expect nine batches of 100.

    Batch ``i`` is compared against ``data`` starting at ``i * batch_size``.
    """
    data = self.get_test_data(999)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty reader would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for i, batch in enumerate(reader):
        self.assert_batch_equal(data, batch, i * batch_size, batch_size)
        num_batches += 1
    self.assertEqual(9, num_batches)
def test_basic(self):
    """Iterate a reader over 1000 samples and expect ten batches of 100.

    Batch ``i`` is compared against ``data`` starting at ``i * batch_size``.
    """
    data = self.get_test_data(1000)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty reader would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for i, batch in enumerate(reader):
        self.assert_batch_equal(data, batch, i * batch_size, batch_size)
        num_batches += 1
    self.assertEqual(10, num_batches)
def test_drop_small_one_worker(self):
    """Stream 999 samples through one worker and expect nine batches of 100.

    With a single shard and a single worker the batches are checked in
    order: batch ``i`` is compared against ``data`` starting at
    ``i * batch_size``.
    """
    data = self.get_test_data(999)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=1)
    streamer = DataStreamer(reader, num_workers=1)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty streamer would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for i, batch in enumerate(streamer):
        self.assert_batch_equal(data, batch, i * batch_size, batch_size)
        num_batches += 1
    self.assertEqual(9, num_batches)
def test_not_drop_small(self):
    """Stream 999 samples with ``drop_small=False`` and expect ten batches.

    The first nine batches hold 100 samples each; the tenth (index 9) is
    expected to hold the remaining 99.
    """
    data = self.get_test_data(999)
    batch_size = 100
    last_index = 9
    last_size = 99  # 999 samples - 9 full batches of 100
    reader = NpArrayReader(data, batch_size=batch_size, drop_small=False)
    streamer = DataStreamer(reader)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty streamer would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for i, batch in enumerate(streamer):
        expected_size = batch_size if i != last_index else last_size
        self.assert_batch_equal(data, batch, i * batch_size, expected_size)
        num_batches += 1
    self.assertEqual(10, num_batches)
def test_basic_two_workers(self):
    """Stream 1000 samples through two workers and check all ten batches.

    The batches are not required to arrive in any particular order: each
    streamed batch only has to equal one of the ten expected splits.
    """
    # NOTE(review): a method with this exact name is defined twice in the
    # visible source; if both live in the same class this later definition
    # silently replaces the earlier one -- confirm and drop one copy.
    data = self.get_test_data(1000)
    batch_size = 100
    num_shards = 2
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=num_shards)
    # Expected batches, built with the reader's own split helper.
    splits = [reader._get_split(reader.data, i, batch_size) for i in range(10)]
    streamer = DataStreamer(reader, num_workers=num_shards)

    # Count explicitly rather than reading the loop variable after the loop:
    # an empty streamer would otherwise raise UnboundLocalError instead of
    # producing a proper test failure.
    num_batches = 0
    for batch in streamer:
        num_batches += 1
        for split in splits:
            try:
                self.assert_batch_equal(split, batch, 0, batch_size)
            except Exception:
                # Deliberately broad: any failure just means "not this
                # split", so move on to the next candidate.
                continue
            break
        else:
            self.fail(
                "streamed batch #%d does not match any expected split"
                % (num_batches - 1)
            )
    self.assertEqual(10, num_batches)