Example #1
0
    def test_auto_init(self):
        """The threaded flow should lazily (re-)start its worker.

        Each call to the generator factory bumps the epoch counter, so the
        values observed in the batches reveal how many times the underlying
        iterator has been re-created across iterations and close() calls.
        """
        n_epochs = [0]

        base = DataStream.int_seq(0, 10, batch_size=2)
        shifted = base.map(lambda v: (v + n_epochs[0] * 10,))

        def build():
            # called once per epoch by the generator stream
            n_epochs[0] += 1
            return shifted

        flow = DataStream.generator(build).threaded(3)

        # first full pass: counter was bumped to 1 before any batch
        collected = [batch[0] for batch in flow]
        np.testing.assert_array_equal(
            [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], collected)

        # second full pass: counter bumped to 2
        collected = [batch[0] for batch in flow]
        np.testing.assert_array_equal(
            [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]], collected)

        # after an explicit close, iterating again must restart the worker
        flow.close()
        collected = [batch[0] for batch in flow]
        np.testing.assert_array_equal(
            [[40, 41], [42, 43], [44, 45], [46, 47], [48, 49]], collected)

        flow.close()
Example #2
0
    def test_copy(self):
        """copy() should yield a new ThreadingDataStream with the overridden prefetch."""
        base = DataStream.int_seq(5, batch_size=3)
        threaded = base.threaded(3)
        self.assertIs(threaded.source, base)
        self.assertEqual(threaded.prefetch, 3)

        # the copy keeps the source but takes the new prefetch value
        clone = threaded.copy(prefetch=1)
        self.assertIsInstance(clone, ThreadingDataStream)
        self.assertIs(clone.source, base)
        self.assertEqual(clone.prefetch, 1)
Example #3
0
    def test_iter_reentrant_warn(self):
        """Re-entrant iteration should auto-close a dangling iterator with a warning."""
        stream = DataStream.int_seq(5, batch_size=3)

        def drain_with_capture():
            # consume the whole stream under warning capture, verify the
            # batches, and hand back whatever warnings were recorded
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter('always')
                got = list(stream)
                self.assertEqual(len(got), 2)
                np.testing.assert_equal(got[0][0], [0, 1, 2])
                np.testing.assert_equal(got[1][0], [3, 4])
            return caught

        # opened and properly closed: no warning expected
        it = iter(stream)
        np.testing.assert_equal(next(it)[0], [0, 1, 2])
        it.close()

        self.assertEqual(len(drain_with_capture()), 0)

        # left open: iterating the stream again should warn and auto-close it
        it = iter(stream)
        np.testing.assert_equal(next(it)[0], [0, 1, 2])

        caught = drain_with_capture()
        self.assertEqual(len(caught), 1)
        self.assertTrue(issubclass(caught[-1].category, UserWarning))
        self.assertRegex(
            str(caught[-1].message),
            r'Another iterator of the DataStream .* is still active, '
            r'will close it automatically.')

        with pytest.raises(StopIteration):
            _ = next(it)  # this iterator should have been closed

        # a fresh iterator after the auto-close should not warn again
        it = iter(stream)
        np.testing.assert_equal(next(it)[0], [0, 1, 2])

        self.assertEqual(len(drain_with_capture()), 0)
Example #4
0
    def test_copy(self):
        """copy() on a MapperDataStream should honor every overridden attribute."""
        rng = np.random.RandomState()
        base = DataStream.int_seq(5, batch_size=3)

        # identity mapper: pass the arrays through unchanged
        mapper = MapperDataStream(base, lambda *arrays: arrays)
        clone = mapper.copy(
            batch_size=7,
            array_count=1,
            data_shapes=((3,),),
            data_length=11,
            random_state=rng,
        )
        self.assertEqual(clone.data_length, 11)
        self.assertEqual(clone.batch_size, 7)
        self.assertEqual(clone.array_count, 1)
        self.assertEqual(clone.data_shapes, ((3,),))
        self.assertIs(clone.random_state, rng)
Example #5
0
    def test_iterator(self):
        """Verify prefetch behavior of a threaded flow across several epochs.

        The mapper computes ``x + epoch * 10 + external * 100``, so every
        observed value encodes *when* its batch was produced: batches
        prefetched before ``external_counter`` was bumped still carry the
        old hundreds digit.  The test drives complete epochs, an epoch
        interrupted by ``break``, and one interrupted by an exception, and
        checks the epoch counter keeps advancing in each case.
        """
        class _MyError(Exception):
            pass

        # mutable cells so the lambda/closure below can observe updates
        epoch_counter = [0]
        external_counter = [1]

        seq_flow = DataStream.int_seq(0, 10, batch_size=2)
        map_flow = seq_flow.map(
            lambda x: (x + epoch_counter[0] * 10 + external_counter[0] * 100,))

        def make_iterator():
            # called once per epoch by the generator stream
            epoch_counter[0] += 1
            return map_flow

        it_flow = DataStream.generator(make_iterator)
        with it_flow.threaded(prefetch=2) as flow:
            # the first epoch, expect 0 .. 10
            np.testing.assert_array_equal(
                [[110, 111], [112, 113], [114, 115], [116, 117], [118, 119]],
                [a[0] for a in flow]
            )
            # give the background thread time to settle before changing state
            time.sleep(.1)
            external_counter[0] += 1

            # the second epoch, the epoch counter should affect more than
            # the external counter
            np.testing.assert_array_equal(
                # having `prefetch = 2` should affect 3 items, because
                # while the queue size is 2, there are 1 additional prefetched
                # item waiting to be enqueued
                [[120, 121], [122, 123], [124, 125], [226, 227], [228, 229]],
                [a[0] for a in flow]
            )
            time.sleep(.1)
            external_counter[0] += 1

            # the third epoch, we shall carry out an incomplete epoch by break
            for a in flow:
                np.testing.assert_array_equal([230, 231], a[0])
                break
            time.sleep(.1)
            external_counter[0] += 1

            # verify that the epoch counter increases after break
            for i, (a,) in enumerate(flow):
                # because the interruption is not well-predictable under
                # multi-threading context, we shall have a weaker verification
                # than the above
                self.assertTrue((340 + i * 2 == a[0]) or (440 + i * 2 == a[0]))
                self.assertTrue((341 + i * 2 == a[1]) or (441 + i * 2 == a[1]))
            time.sleep(.1)
            external_counter[0] += 1

            # carry out the fourth, incomplete epoch by error
            try:
                for a in flow:
                    np.testing.assert_array_equal([450, 451], a[0])
                    raise _MyError()
            except _MyError:
                pass
            time.sleep(.1)
            external_counter[0] += 1

            # verify that the epoch counter increases after error
            for i, (a,) in enumerate(flow):
                self.assertTrue((560 + i * 2 == a[0]) or (660 + i * 2 == a[0]))
                self.assertTrue((561 + i * 2 == a[1]) or (661 + i * 2 == a[1]))
Example #6
0
    def test_stream(self):
        """GatherDataStream should validate its inputs and inherit properties."""
        # an empty list of streams is rejected
        with pytest.raises(ValueError,
                           match='At least one data stream should be '
                                 'specified'):
            _ = GatherDataStream([])

        # non-DataStream elements are rejected with their position reported
        with pytest.raises(TypeError,
                           match='The 1-th element of `streams` is not '
                                 'an instance of DataStream: <object.*>'):
            _ = GatherDataStream([DataStream.int_seq(5, batch_size=3),
                                  object()])

        def empty_gen():
            # a generator stream that yields nothing and reports no properties
            if False:
                yield

        # conflicting batch sizes among the member streams are rejected
        with pytest.raises(ValueError,
                           match='Inconsistent batch size among the specified '
                                 'streams: encountered 4 at the 3-th stream, '
                                 'but has already encountered 3 before.'):
            _ = GatherDataStream([
                DataStream.generator(empty_gen),
                DataStream.int_seq(5, batch_size=3),
                DataStream.generator(empty_gen),
                DataStream.int_seq(5, batch_size=4),
            ])

        # conflicting data lengths are rejected likewise
        with pytest.raises(ValueError,
                           match='Inconsistent data length among the specified '
                                 'streams: encountered 6 at the 3-th stream, '
                                 'but has already encountered 5 before.'):
            _ = GatherDataStream([
                DataStream.generator(empty_gen),
                DataStream.int_seq(5, batch_size=3),
                DataStream.generator(empty_gen),
                DataStream.int_seq(6, batch_size=3),
            ])

        # properties are inherited from the member streams
        init_rs = np.random.RandomState(1234)
        arr_x = init_rs.normal(size=[5, 1])
        arr_y = init_rs.normal(size=[5, 2])
        arr_z = init_rs.normal(size=[5, 3])

        rs = np.random.RandomState(1234)
        stream_x = DataStream.arrays([arr_x], batch_size=3, random_state=rs)
        stream_yz = DataStream.arrays([arr_y, arr_z], batch_size=3)

        gathered = GatherDataStream([stream_x, stream_yz])
        self.assertTupleEqual(gathered.streams, (stream_x, stream_yz))
        self.assertEqual(gathered.batch_size, 3)
        self.assertEqual(gathered.data_shapes, ((1,), (2,), (3,)))
        self.assertEqual(gathered.data_length, 5)
        self.assertEqual(gathered.array_count, 3)
        self.assertIs(gathered.random_state, rs)

        fetched = gathered.get_arrays()
        self.assertEqual(len(fetched), 3)
        for got, expected in zip(fetched, (arr_x, arr_y, arr_z)):
            np.testing.assert_equal(got, expected)

        # when no member stream defines a property, it stays None
        gen_1 = DataStream.generator(empty_gen)
        gen_2 = DataStream.generator(empty_gen)

        gathered = GatherDataStream([gen_1, gen_2])
        for attr in ('batch_size', 'array_count', 'data_length', 'data_shapes',
                     'random_state'):
            self.assertIsNone(getattr(gathered, attr))

        # an explicitly supplied random_state always wins
        gathered = GatherDataStream([gen_1, gen_2], random_state=rs)
        self.assertIs(gathered.random_state, rs)