def test_auto_init(self):
    epoch_counter = [0]

    seq_flow = DataStream.int_seq(0, 10, batch_size=2)
    map_flow = seq_flow.map(
        lambda x: (x + epoch_counter[0] * 10,))

    def make_iterator():
        epoch_counter[0] += 1
        return map_flow

    it_flow = DataStream.generator(make_iterator)
    flow = it_flow.threaded(3)

    # first full epoch: epoch_counter == 1
    batches = [b[0] for b in flow]
    np.testing.assert_array_equal(
        [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
        batches)

    # second full epoch: epoch_counter == 2
    batches = [b[0] for b in flow]
    np.testing.assert_array_equal(
        [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]],
        batches)

    # after close(), re-iterating re-initializes the worker; the expected
    # values indicate the epoch counter has advanced to 4 by then
    flow.close()
    batches = [b[0] for b in flow]
    np.testing.assert_array_equal(
        [[40, 41], [42, 43], [44, 45], [46, 47], [48, 49]],
        batches)
    flow.close()
def test_copy(self):
    source = DataStream.int_seq(5, batch_size=3)
    stream = source.threaded(3)
    self.assertIs(stream.source, source)
    self.assertEqual(stream.prefetch, 3)

    stream2 = stream.copy(prefetch=1)
    self.assertIsInstance(stream2, ThreadingDataStream)
    self.assertIs(stream2.source, source)
    self.assertEqual(stream2.prefetch, 1)
def test_iter_reentrant_warn(self):
    stream = DataStream.int_seq(5, batch_size=3)

    # test open and close, no warning
    iterator = iter(stream)
    np.testing.assert_equal(next(iterator)[0], [0, 1, 2])
    iterator.close()

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        batches = list(stream)
        self.assertEqual(len(batches), 2)
        np.testing.assert_equal(batches[0][0], [0, 1, 2])
        np.testing.assert_equal(batches[1][0], [3, 4])
    self.assertEqual(len(w), 0)

    # test open without close, cause warning
    iterator = iter(stream)
    np.testing.assert_equal(next(iterator)[0], [0, 1, 2])

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        batches = list(stream)
        self.assertEqual(len(batches), 2)
        np.testing.assert_equal(batches[0][0], [0, 1, 2])
        np.testing.assert_equal(batches[1][0], [3, 4])
    self.assertEqual(len(w), 1)
    self.assertTrue(issubclass(w[-1].category, UserWarning))
    self.assertRegex(
        str(w[-1].message),
        r'Another iterator of the DataStream .* is still active, '
        r'will close it automatically.')

    with pytest.raises(StopIteration):
        _ = next(iterator)  # this iterator should have been closed

    # test no warning the second time
    iterator = iter(stream)
    np.testing.assert_equal(next(iterator)[0], [0, 1, 2])

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        batches = list(stream)
        self.assertEqual(len(batches), 2)
        np.testing.assert_equal(batches[0][0], [0, 1, 2])
        np.testing.assert_equal(batches[1][0], [3, 4])
    self.assertEqual(len(w), 0)
def test_copy(self):
    rs = np.random.RandomState()
    source = DataStream.int_seq(5, batch_size=3)
    mapped = MapperDataStream(source, lambda *args: args)

    mapped2 = mapped.copy(batch_size=7, array_count=1, data_shapes=((3,),),
                          data_length=11, random_state=rs)
    self.assertEqual(mapped2.data_length, 11)
    self.assertEqual(mapped2.batch_size, 7)
    self.assertEqual(mapped2.array_count, 1)
    self.assertEqual(mapped2.data_shapes, ((3,),))
    self.assertIs(mapped2.random_state, rs)
def test_iterator(self):
    class _MyError(Exception):
        pass

    epoch_counter = [0]
    external_counter = [1]

    seq_flow = DataStream.int_seq(0, 10, batch_size=2)
    map_flow = seq_flow.map(
        lambda x: (x + epoch_counter[0] * 10 + external_counter[0] * 100,))

    def make_iterator():
        epoch_counter[0] += 1
        return map_flow

    it_flow = DataStream.generator(make_iterator)

    with it_flow.threaded(prefetch=2) as flow:
        # the first epoch: the underlying 0 .. 10 sequence, shifted by the
        # counters to 110 .. 119
        np.testing.assert_array_equal(
            [[110, 111], [112, 113], [114, 115], [116, 117], [118, 119]],
            [a[0] for a in flow]
        )
        time.sleep(.1)
        external_counter[0] += 1

        # the second epoch: the epoch counter affects every item, while the
        # external counter increment only affects items not yet prefetched
        np.testing.assert_array_equal(
            # having `prefetch = 2` should affect 3 items, because
            # while the queue size is 2, there is 1 additional prefetched
            # item waiting to be enqueued
            [[120, 121], [122, 123], [124, 125], [226, 227], [228, 229]],
            [a[0] for a in flow]
        )
        time.sleep(.1)
        external_counter[0] += 1

        # the third epoch, we shall carry out an incomplete epoch by break
        for a in flow:
            np.testing.assert_array_equal([230, 231], a[0])
            break
        time.sleep(.1)
        external_counter[0] += 1

        # verify that the epoch counter increases after break
        for i, (a,) in enumerate(flow):
            # because the interruption is not well-predictable under a
            # multi-threading context, we shall have a weaker verification
            # than the above
            self.assertTrue((340 + i * 2 == a[0]) or (440 + i * 2 == a[0]))
            self.assertTrue((341 + i * 2 == a[1]) or (441 + i * 2 == a[1]))
        time.sleep(.1)
        external_counter[0] += 1

        # carry out the fourth, incomplete epoch by error
        try:
            for a in flow:
                np.testing.assert_array_equal([450, 451], a[0])
                raise _MyError()
        except _MyError:
            pass
        time.sleep(.1)
        external_counter[0] += 1

        # verify that the epoch counter increases after error
        for i, (a,) in enumerate(flow):
            self.assertTrue((560 + i * 2 == a[0]) or (660 + i * 2 == a[0]))
            self.assertTrue((561 + i * 2 == a[1]) or (661 + i * 2 == a[1]))
def test_stream(self):
    # argument validation
    with pytest.raises(ValueError,
                       match='At least one data stream should be '
                             'specified'):
        _ = GatherDataStream([])

    with pytest.raises(TypeError,
                       match='The 1-th element of `streams` is not '
                             'an instance of DataStream: <object.*>'):
        _ = GatherDataStream([DataStream.int_seq(5, batch_size=3),
                              object()])

    def my_generator():
        if False:
            yield

    with pytest.raises(ValueError,
                       match='Inconsistent batch size among the specified '
                             'streams: encountered 4 at the 3-th stream, '
                             'but has already encountered 3 before.'):
        _ = GatherDataStream([
            DataStream.generator(my_generator),
            DataStream.int_seq(5, batch_size=3),
            DataStream.generator(my_generator),
            DataStream.int_seq(5, batch_size=4),
        ])

    with pytest.raises(ValueError,
                       match='Inconsistent data length among the specified '
                             'streams: encountered 6 at the 3-th stream, '
                             'but has already encountered 5 before.'):
        _ = GatherDataStream([
            DataStream.generator(my_generator),
            DataStream.int_seq(5, batch_size=3),
            DataStream.generator(my_generator),
            DataStream.int_seq(6, batch_size=3),
        ])

    # test property inheritance
    rs0 = np.random.RandomState(1234)
    x = rs0.normal(size=[5, 1])
    y = rs0.normal(size=[5, 2])
    z = rs0.normal(size=[5, 3])

    rs = np.random.RandomState(1234)
    stream_x = DataStream.arrays([x], batch_size=3, random_state=rs)
    stream_yz = DataStream.arrays([y, z], batch_size=3)
    stream = GatherDataStream([stream_x, stream_yz])

    self.assertTupleEqual(stream.streams, (stream_x, stream_yz))
    self.assertEqual(stream.batch_size, 3)
    self.assertEqual(stream.data_shapes, ((1,), (2,), (3,)))
    self.assertEqual(stream.data_length, 5)
    self.assertEqual(stream.array_count, 3)
    self.assertIs(stream.random_state, rs)

    arrays = stream.get_arrays()
    self.assertEqual(len(arrays), 3)
    np.testing.assert_equal(arrays[0], x)
    np.testing.assert_equal(arrays[1], y)
    np.testing.assert_equal(arrays[2], z)

    # test no property to inherit
    stream_1 = DataStream.generator(my_generator)
    stream_2 = DataStream.generator(my_generator)
    stream = GatherDataStream([stream_1, stream_2])
    for attr in ('batch_size', 'array_count', 'data_length', 'data_shapes',
                 'random_state'):
        self.assertIsNone(getattr(stream, attr))

    stream = GatherDataStream([stream_1, stream_2], random_state=rs)
    self.assertIs(stream.random_state, rs)