Ejemplo n.º 1
0
    def test_select(self):
        """Exercise ``select`` on a mapped stream, which cannot report its
        ``array_count`` or ``data_shapes`` ahead of iteration."""

        def mapper(x, y, z):
            return x + y, y - z

        x = np.random.normal(size=[5, 3])
        y = np.random.normal(size=[5, 1])
        z = np.random.normal(size=[5, 3])

        # after `map`, the stream no longer knows its array count or shapes
        mapped = DataStream.arrays([x, y, z], batch_size=3).map(mapper)
        self.assertIsNone(mapped.data_shapes)
        self.assertIsNone(mapped.array_count)

        # indices may be negative and may repeat; the count is known from
        # the index list even though the shapes are not
        selected = mapped.select([-1, 0, 1])
        self.assertEqual(selected.array_count, 3)
        self.assertIsNone(selected.data_shapes)

        first, second, third = selected.get_arrays()
        pairs = ((first, y - z), (second, x + y), (third, y - z))
        for actual, expected in pairs:
            np.testing.assert_allclose(actual, expected)

        # `mapper` yields only two arrays, so index 2 is invalid; the error
        # surfaces while iterating, not when `select` is called
        out_of_range = mapped.select([0, 1, 2])
        with pytest.raises(IndexError, match='.* index out of range'):
            for _batch in out_of_range:
                pass
Ejemplo n.º 2
0
    def test_override_random_state(self):
        """An explicit ``random_state`` passed to GatherDataStream, or to its
        ``copy``, is the one the resulting stream reports."""
        data_rs = np.random.RandomState(1234)
        # drawn in order: width 1, then 2, then 3 (same draws as before)
        arr_x, arr_y, arr_z = (data_rs.normal(size=[5, width])
                               for width in (1, 2, 3))

        inherited_rs = np.random.RandomState(1234)
        stream_x = DataStream.arrays(
            [arr_x], batch_size=3, random_state=inherited_rs)
        stream_yz = DataStream.arrays([arr_y, arr_z], batch_size=3)

        # the constructor argument is used instead of the source's state
        ctor_rs = np.random.RandomState(1234)
        gathered = GatherDataStream([stream_x, stream_yz],
                                    random_state=ctor_rs)
        self.assertIs(gathered.random_state, ctor_rs)

        # copy() takes its own override and keeps the concrete class
        copy_rs = np.random.RandomState(1234)
        copied = gathered.copy(random_state=copy_rs)
        self.assertIsInstance(copied, GatherDataStream)
        self.assertIs(copied.random_state, copy_rs)
Ejemplo n.º 3
0
    def test_to_arrays_stream(self):
        """``to_arrays_stream`` keeps an absent random state absent and
        adopts one that is passed explicitly."""
        first = np.random.normal(size=[5, 4])
        second = np.random.normal(size=[5, 2, 3])
        base = DataStream.arrays([first, second], batch_size=3)
        self.assertIsNone(base.random_state)

        # no override given: still no random state
        plain = base.to_arrays_stream()
        self.assertIsNone(plain.random_state)

        # override given: that exact object is used
        override = np.random.RandomState()
        overridden = base.to_arrays_stream(random_state=override)
        self.assertIs(overridden.random_state, override)
Ejemplo n.º 4
0
    def test_stream(self):
        """Check ThreadingDataStream: argument validation, forwarding of the
        source's properties, and starting/stopping of the background worker
        both inside and outside a ``with`` block."""
        # Draw the data from a local RandomState rather than calling
        # np.random.seed(), so this test does not reseed NumPy's global RNG
        # for other tests (the same pattern the other tests here use).
        data_rs = np.random.RandomState(1234)
        rs = np.random.RandomState()
        x = data_rs.normal(size=[5, 1])
        y = data_rs.normal(size=[5, 2])
        source = DataStream.arrays([x, y], batch_size=3, random_state=rs)

        # test argument validation
        with pytest.raises(TypeError,
                           match='`source` is not a DataStream: <object.*>'):
            _ = ThreadingDataStream(object(), prefetch=2)

        with pytest.raises(ValueError,
                           match='`prefetch` must be at least 1'):
            _ = ThreadingDataStream(source, prefetch=0)

        # test threaded with context: properties are taken from `source`
        stream = ThreadingDataStream(source, prefetch=3)
        self.assertIs(stream.source, source)
        self.assertEqual(stream.prefetch, 3)
        self.assertEqual(stream.batch_size, 3)
        self.assertEqual(stream.array_count, 2)
        self.assertIs(stream.random_state, rs)
        self.assertEqual(stream.data_length, 5)
        self.assertEqual(stream.data_shapes, ((1,), (2,)))

        # the worker is not running before or after the ``with`` block
        self.assertFalse(stream._worker_alive)
        self.assertFalse(stream._initialized)
        with stream:
            arrays = stream.get_arrays()
            self.assertEqual(len(arrays), 2)
            np.testing.assert_equal(arrays[0], x)
            np.testing.assert_equal(arrays[1], y)
        self.assertFalse(stream._worker_alive)
        self.assertFalse(stream._initialized)

        # test threaded without context: get_arrays() starts the worker
        # lazily, and close() has to be called explicitly
        stream = ThreadingDataStream(source, prefetch=3)
        self.assertFalse(stream._worker_alive)
        self.assertFalse(stream._initialized)
        try:
            arrays = stream.get_arrays()
            self.assertTrue(stream._worker_alive)
            self.assertTrue(stream._initialized)
        finally:
            # close even if get_arrays() or an assertion above fails, so a
            # failing run does not leave the worker alive
            stream.close()
        self.assertFalse(stream._worker_alive)
        self.assertFalse(stream._initialized)
        stream.close()  # double close should not cause an error

        self.assertEqual(len(arrays), 2)
        np.testing.assert_equal(arrays[0], x)
        np.testing.assert_equal(arrays[1], y)
Ejemplo n.º 5
0
    def test_stream(self):
        """Check MapperDataStream: argument validation, inheritance of the
        source's properties (with and without ``preserve_shapes``), and
        explicit overrides of those properties."""
        def identity(*args):
            return args

        # test argument validation
        with pytest.raises(TypeError,
                           match='`source` is not a DataStream: <object.*>'):
            _ = MapperDataStream(object(), identity)

        # test property inheritance
        # A local RandomState replaces np.random.seed(1234), so the global
        # NumPy RNG is not reseeded as a side effect of this test.
        data_rs = np.random.RandomState(1234)
        rs = np.random.RandomState()
        x = data_rs.normal(size=[5, 1])
        y = data_rs.normal(size=[5, 2])
        source = DataStream.arrays([x, y], batch_size=3, random_state=rs)

        # preserve_shapes=False: array_count / data_shapes are unknown
        mapped = MapperDataStream(source, identity, preserve_shapes=False)
        self.assertIs(mapped.source, source)
        self.assertEqual(mapped.data_length, 5)
        self.assertEqual(mapped.batch_size, 3)
        self.assertEqual(mapped.batch_count, 2)
        self.assertIsNone(mapped.array_count)
        self.assertIsNone(mapped.data_shapes)
        self.assertIs(mapped.random_state, rs)

        # preserve_shapes=True: array_count / data_shapes come from source
        mapped = MapperDataStream(source, identity, preserve_shapes=True)
        self.assertEqual(mapped.data_length, 5)
        self.assertEqual(mapped.batch_size, 3)
        self.assertEqual(mapped.batch_count, 2)
        self.assertEqual(mapped.array_count, 2)
        self.assertEqual(mapped.data_shapes, ((1, ), (2, )))
        self.assertIs(mapped.random_state, rs)

        # test override: explicit arguments replace the source's properties
        rs2 = np.random.RandomState()
        mapped = MapperDataStream(
            source,
            identity,
            # in fact, these overrides are incorrect
            batch_size=7,
            array_count=1,
            data_shapes=((3, ), ),
            data_length=11,
            random_state=rs2,
        )
        self.assertEqual(mapped.data_length, 11)
        self.assertEqual(mapped.batch_size, 7)
        self.assertEqual(mapped.batch_count, 2)
        self.assertEqual(mapped.array_count, 1)
        self.assertEqual(mapped.data_shapes, ((3, ), ))
        self.assertIs(mapped.random_state, rs2)
Ejemplo n.º 6
0
    def test_stream(self):
        """Check GatherDataStream: argument validation, consistency checks
        across the gathered streams, inheritance of properties from them,
        and the all-``None`` case when nothing can be inherited."""

        def my_generator():
            # an empty generator: yields nothing, reports no properties
            if False:
                yield

        def four_streams(last_length, last_batch_size):
            # two property-less streams interleaved with two int_seq streams;
            # only the last (3-th, zero-based) stream's settings vary
            return [
                DataStream.generator(my_generator),
                DataStream.int_seq(5, batch_size=3),
                DataStream.generator(my_generator),
                DataStream.int_seq(last_length, batch_size=last_batch_size),
            ]

        # argument validation
        with pytest.raises(ValueError,
                           match='At least one data stream should be '
                                 'specified'):
            _ = GatherDataStream([])

        with pytest.raises(TypeError,
                           match='The 1-th element of `streams` is not '
                                 'an instance of DataStream: <object.*>'):
            _ = GatherDataStream([DataStream.int_seq(5, batch_size=3),
                                  object()])

        # streams without a property are skipped when checking consistency
        with pytest.raises(ValueError,
                           match='Inconsistent batch size among the specified '
                                 'streams: encountered 4 at the 3-th stream, '
                                 'but has already encountered 3 before.'):
            _ = GatherDataStream(four_streams(5, 4))

        with pytest.raises(ValueError,
                           match='Inconsistent data length among the specified '
                                 'streams: encountered 6 at the 3-th stream, '
                                 'but has already encountered 5 before.'):
            _ = GatherDataStream(four_streams(6, 3))

        # test property inheritance
        data_rs = np.random.RandomState(1234)
        # drawn in order: width 1, then 2, then 3 (same draws as before)
        x, y, z = (data_rs.normal(size=[5, width]) for width in (1, 2, 3))

        rs = np.random.RandomState(1234)
        stream_x = DataStream.arrays([x], batch_size=3, random_state=rs)
        stream_yz = DataStream.arrays([y, z], batch_size=3)

        stream = GatherDataStream([stream_x, stream_yz])
        self.assertTupleEqual(stream.streams, (stream_x, stream_yz))
        expected_props = (
            ('batch_size', 3),
            ('data_shapes', ((1,), (2,), (3,))),
            ('data_length', 5),
            ('array_count', 3),
        )
        for attr, expected in expected_props:
            self.assertEqual(getattr(stream, attr), expected)
        self.assertIs(stream.random_state, rs)

        # arrays of all gathered streams come back in stream order
        arrays = stream.get_arrays()
        self.assertEqual(len(arrays), 3)
        for got, want in zip(arrays, (x, y, z)):
            np.testing.assert_equal(got, want)

        # test no property to inherit
        stream_1 = DataStream.generator(my_generator)
        stream_2 = DataStream.generator(my_generator)

        stream = GatherDataStream([stream_1, stream_2])
        for attr in ('batch_size', 'array_count', 'data_length', 'data_shapes',
                     'random_state'):
            self.assertIsNone(getattr(stream, attr))

        # an explicit random_state is still honoured
        stream = GatherDataStream([stream_1, stream_2], random_state=rs)
        self.assertIs(stream.random_state, rs)