def test_select(self):
    """Check ``select()`` on a source that reports no array metadata."""
    # Additional tests for the case where the source stream reports neither
    # `array_count` nor `data_shapes`.
    def add_and_subtract(x, y, z):
        return x + y, y - z

    x = np.random.normal(size=[5, 3])
    y = np.random.normal(size=[5, 1])
    z = np.random.normal(size=[5, 3])
    source = DataStream.arrays([x, y, z], batch_size=3).map(add_and_subtract)
    self.assertIsNone(source.data_shapes)
    self.assertIsNone(source.array_count)

    # The mapped source yields two arrays: (x + y, y - z).
    expected_sum = x + y
    expected_diff = y - z

    # Selecting three indices fixes the array count, but not the shapes.
    selected = source.select([-1, 0, 1])
    self.assertEqual(selected.array_count, 3)
    self.assertIsNone(selected.data_shapes)

    first, second, third = selected.get_arrays()
    np.testing.assert_allclose(first, expected_diff)
    np.testing.assert_allclose(second, expected_sum)
    np.testing.assert_allclose(third, expected_diff)

    # An out-of-range index is only reported once the stream is iterated.
    out_of_range = source.select([0, 1, 2])
    with pytest.raises(IndexError, match='.* index out of range'):
        for _ in out_of_range:
            pass
def test_override_random_state(self):
    """Check that an explicit ``random_state`` is used by the constructor and by ``copy()``."""
    data_rng = np.random.RandomState(1234)
    x = data_rng.normal(size=[5, 1])
    y = data_rng.normal(size=[5, 2])
    z = data_rng.normal(size=[5, 3])

    x_rng = np.random.RandomState(1234)
    stream_x = DataStream.arrays([x], batch_size=3, random_state=x_rng)
    stream_yz = DataStream.arrays([y, z], batch_size=3)

    # A random state passed to the constructor must be exposed as-is.
    ctor_rng = np.random.RandomState(1234)
    gathered = GatherDataStream([stream_x, stream_yz], random_state=ctor_rng)
    self.assertIs(gathered.random_state, ctor_rng)

    # Copying with an overridden random state must keep the stream type
    # and expose the new random state.
    copy_rng = np.random.RandomState(1234)
    copied = gathered.copy(random_state=copy_rng)
    self.assertIsInstance(copied, GatherDataStream)
    self.assertIs(copied.random_state, copy_rng)
def test_to_arrays_stream(self):
    """Check ``random_state`` handling of ``to_arrays_stream()``."""
    x = np.random.normal(size=[5, 4])
    y = np.random.normal(size=[5, 2, 3])

    original = DataStream.arrays([x, y], batch_size=3)
    self.assertIsNone(original.random_state)

    # Without an explicit random state, the converted stream has none either.
    converted = original.to_arrays_stream()
    self.assertIsNone(converted.random_state)

    # An explicit random state is exposed by the converted stream.
    explicit_rng = np.random.RandomState()
    converted_with_rng = original.to_arrays_stream(random_state=explicit_rng)
    self.assertIs(converted_with_rng.random_state, explicit_rng)
def test_stream(self):
    """Check ThreadingDataStream validation, properties and worker lifecycle."""
    def assert_worker_flags(threaded, expected):
        # `_worker_alive` and `_initialized` are expected to flip together.
        check = self.assertTrue if expected else self.assertFalse
        check(threaded._worker_alive)
        check(threaded._initialized)

    np.random.seed(1234)
    rs = np.random.RandomState()
    x = np.random.normal(size=[5, 1])
    y = np.random.normal(size=[5, 2])
    source = DataStream.arrays([x, y], batch_size=3, random_state=rs)

    # Argument validation.
    with pytest.raises(TypeError,
                       match='`source` is not a DataStream: <object.*>'):
        _ = ThreadingDataStream(object(), prefetch=2)

    with pytest.raises(ValueError,
                       match='`prefetch` must be at least 1'):
        _ = ThreadingDataStream(source, prefetch=0)

    # Threaded stream used as a context manager.
    threaded = ThreadingDataStream(source, prefetch=3)
    self.assertIs(threaded.source, source)
    self.assertEqual(threaded.prefetch, 3)
    self.assertEqual(threaded.batch_size, 3)
    self.assertEqual(threaded.array_count, 2)
    self.assertIs(threaded.random_state, rs)
    self.assertEqual(threaded.data_length, 5)
    self.assertEqual(threaded.data_shapes, ((1,), (2,)))
    assert_worker_flags(threaded, False)

    with threaded:
        fetched = threaded.get_arrays()
        self.assertEqual(len(fetched), 2)
        np.testing.assert_equal(fetched[0], x)
        np.testing.assert_equal(fetched[1], y)

    # Both flags must be cleared after leaving the context.
    assert_worker_flags(threaded, False)

    # Threaded stream used without a context manager.
    threaded = ThreadingDataStream(source, prefetch=3)
    assert_worker_flags(threaded, False)

    fetched = threaded.get_arrays()
    assert_worker_flags(threaded, True)

    threaded.close()
    assert_worker_flags(threaded, False)
    threaded.close()  # double close should not cause an error

    self.assertEqual(len(fetched), 2)
    np.testing.assert_equal(fetched[0], x)
    np.testing.assert_equal(fetched[1], y)
def test_stream(self):
    """Check MapperDataStream validation, inherited properties and overrides."""
    def identity(*args):
        return args

    # Argument validation.
    with pytest.raises(TypeError,
                       match='`source` is not a DataStream: <object.*>'):
        _ = MapperDataStream(object(), identity)

    # Property inheritance.
    np.random.seed(1234)
    source_rng = np.random.RandomState()
    x = np.random.normal(size=[5, 1])
    y = np.random.normal(size=[5, 2])
    source = DataStream.arrays(
        [x, y], batch_size=3, random_state=source_rng)

    # With `preserve_shapes=False`, array count and shapes are not reported.
    no_shapes = MapperDataStream(source, identity, preserve_shapes=False)
    self.assertIs(no_shapes.source, source)
    self.assertEqual(no_shapes.data_length, 5)
    self.assertEqual(no_shapes.batch_size, 3)
    self.assertEqual(no_shapes.batch_count, 2)
    self.assertIsNone(no_shapes.array_count)
    self.assertIsNone(no_shapes.data_shapes)
    self.assertIs(no_shapes.random_state, source_rng)

    # With `preserve_shapes=True`, they are taken over from the source.
    with_shapes = MapperDataStream(source, identity, preserve_shapes=True)
    self.assertEqual(with_shapes.data_length, 5)
    self.assertEqual(with_shapes.batch_size, 3)
    self.assertEqual(with_shapes.batch_count, 2)
    self.assertEqual(with_shapes.array_count, 2)
    self.assertEqual(with_shapes.data_shapes, ((1, ), (2, )))
    self.assertIs(with_shapes.random_state, source_rng)

    # Explicit overrides win over the source's properties.  In fact, these
    # overrides are incorrect for the actual data; only the property
    # plumbing is checked here.
    override_rng = np.random.RandomState()
    overrides = dict(
        batch_size=7,
        array_count=1,
        data_shapes=((3, ), ),
        data_length=11,
        random_state=override_rng,
    )
    overridden = MapperDataStream(source, identity, **overrides)
    self.assertEqual(overridden.data_length, 11)
    self.assertEqual(overridden.batch_size, 7)
    self.assertEqual(overridden.batch_count, 2)
    self.assertEqual(overridden.array_count, 1)
    self.assertEqual(overridden.data_shapes, ((3, ), ))
    self.assertIs(overridden.random_state, override_rng)
def test_stream(self):
    """Check GatherDataStream validation and property inheritance."""
    def my_generator():
        # A generator function that yields nothing.
        if False:
            yield

    # Argument validation.
    with pytest.raises(ValueError,
                       match='At least one data stream should be '
                             'specified'):
        _ = GatherDataStream([])

    with pytest.raises(TypeError,
                       match='The 1-th element of `streams` is not '
                             'an instance of DataStream: <object.*>'):
        _ = GatherDataStream(
            [DataStream.int_seq(5, batch_size=3), object()])

    mismatched_batch_size = [
        DataStream.generator(my_generator),
        DataStream.int_seq(5, batch_size=3),
        DataStream.generator(my_generator),
        DataStream.int_seq(5, batch_size=4),
    ]
    with pytest.raises(ValueError,
                       match='Inconsistent batch size among the specified '
                             'streams: encountered 4 at the 3-th stream, '
                             'but has already encountered 3 before.'):
        _ = GatherDataStream(mismatched_batch_size)

    mismatched_data_length = [
        DataStream.generator(my_generator),
        DataStream.int_seq(5, batch_size=3),
        DataStream.generator(my_generator),
        DataStream.int_seq(6, batch_size=3),
    ]
    with pytest.raises(ValueError,
                       match='Inconsistent data length among the specified '
                             'streams: encountered 6 at the 3-th stream, '
                             'but has already encountered 5 before.'):
        _ = GatherDataStream(mismatched_data_length)

    # Property inheritance.
    data_rng = np.random.RandomState(1234)
    x = data_rng.normal(size=[5, 1])
    y = data_rng.normal(size=[5, 2])
    z = data_rng.normal(size=[5, 3])

    rs = np.random.RandomState(1234)
    stream_x = DataStream.arrays([x], batch_size=3, random_state=rs)
    stream_yz = DataStream.arrays([y, z], batch_size=3)

    gathered = GatherDataStream([stream_x, stream_yz])
    self.assertTupleEqual(gathered.streams, (stream_x, stream_yz))
    self.assertEqual(gathered.batch_size, 3)
    self.assertEqual(gathered.data_shapes, ((1,), (2,), (3,)))
    self.assertEqual(gathered.data_length, 5)
    self.assertEqual(gathered.array_count, 3)
    self.assertIs(gathered.random_state, rs)

    arrays = gathered.get_arrays()
    self.assertEqual(len(arrays), 3)
    np.testing.assert_equal(arrays[0], x)
    np.testing.assert_equal(arrays[1], y)
    np.testing.assert_equal(arrays[2], z)

    # No property to inherit: every attribute below is expected to be None.
    stream_1 = DataStream.generator(my_generator)
    stream_2 = DataStream.generator(my_generator)
    gathered = GatherDataStream([stream_1, stream_2])
    inherited_attrs = ('batch_size', 'array_count', 'data_length',
                       'data_shapes', 'random_state')
    for attr in inherited_attrs:
        self.assertIsNone(getattr(gathered, attr))

    # An explicit random state is still honoured in that case.
    gathered = GatherDataStream([stream_1, stream_2], random_state=rs)
    self.assertIs(gathered.random_state, rs)