def test_auto_init(self):
        """
        Iterate a threaded flow directly, without a ``with`` block.

        Every call to the iterator factory bumps the epoch number, and the
        mapped flow adds ``epoch * 10`` to the sequence ``0 .. 9`` (served
        in batches of 2), so the offset of the received batches reveals
        which epoch produced them.  The flow is iterated twice in a row,
        closed, and then iterated once more.
        """
        epoch_no = [0]

        def shift_by_epoch(x):
            # the epoch number is read at mapping time, not at definition
            return (x + epoch_no[0] * 10,)

        def next_epoch_flow():
            epoch_no[0] += 1
            return shifted_flow

        def expected_epoch(first_value):
            # five consecutive batches of two values each
            return np.arange(first_value, first_value + 10).reshape(5, 2)

        def first_arrays():
            return [batch[0] for batch in threaded_flow]

        shifted_flow = DataFlow.seq(0, 10, batch_size=2).map(shift_by_epoch)
        threaded_flow = DataFlow.iterator_factory(next_epoch_flow).threaded(3)

        # first pass over the flow: offset 10
        received = first_arrays()
        np.testing.assert_array_equal(expected_epoch(10), received)

        # second pass over the same (still open) flow: offset 20
        received = first_arrays()
        np.testing.assert_array_equal(expected_epoch(20), received)

        # After an explicit close, iterating again is expected to yield
        # offset 40 rather than 30.
        # NOTE(review): presumably the background worker had already asked
        # the factory for a further epoch before ``close()`` -- confirm
        # against the threaded flow implementation.
        threaded_flow.close()
        received = first_arrays()
        np.testing.assert_array_equal(expected_epoch(40), received)

        threaded_flow.close()
Example #2
0
    def iter_steps(self, data_generator=None):
        """
        Iterate through the steps.

        This method can only be called when there's no other step loop
        is being iterated, and an epoch loop is active.

        Args:
            data_generator: Optional iterable data to be yielded at every step.
                This is required if `max_step` is not configured, so as to
                prevent an infinite step loop.

        Yields:
            int or (int, any): The global step counter (starting from 1), or
                the tuple of ``(step counter, batch data)`` if `data_generator`
                is specified.
        """
        def loop_condition():
            return self._max_step is None or self.step < self._max_step

        self._require_entered()
        if not self._within_epoch:
            raise RuntimeError('Step loop must be opened within active epoch '
                               'loop')
        if self._within_step:
            raise RuntimeError('Another step loop has been opened')
        if self._max_step is None and data_generator is None:
            raise RuntimeError('`data_generator` is required when `max_step` '
                               'is not configured, so as to prevent an '
                               'unstoppable step loop')

        try:
            if data_generator is not None:
                if isinstance(data_generator, DataFlow):
                    data_flow = data_generator
                else:

                    def iter_factory():
                        if data_gen[0] is not None:
                            for batch in data_gen[0]:
                                yield batch
                        data_gen[0] = None  # force to use data_generator once

                    data_gen = [data_generator]
                    data_flow = DataFlow.iterator_factory(iter_factory)
                self._data_flow = data_flow

            while loop_condition():
                # prepare for the step data
                if self._data_flow is None:
                    yield_obj = self.step + 1
                    step_data = None
                else:
                    try:
                        step_data = self._data_flow.next_batch()
                    except StopIteration:
                        break
                    yield_obj = self.step + 1, step_data

                # yield this step
                self._states.step += 1
                self._within_step = True
                self._step_data = step_data
                self._step_start_time = time.time()

                self.events.fire(EventKeys.BEFORE_STEP, self)
                try:
                    yield yield_obj
                except StopIteration:  # pragma: no cover
                    # might be caused by call to ``data_flow.next_batch()``
                    break
                self.events.reverse_fire(EventKeys.AFTER_STEP, self)

                self._commit_step_stop_time()
        finally:
            self._within_step = False
            self._step_start_time = None
            self._data_flow = None
            self._step_data = None
    def test_iterator(self):
        """
        Exercise a threaded flow (``prefetch=2``) over several epochs.

        Each batch of the sequence ``0 .. 9`` (batch size 2) is shifted by
        ``epoch * 10 + outer * 100``.  The epoch number grows whenever the
        iterator factory is called; the outer number is bumped by the test
        between epochs.  The offsets of the received batches therefore show
        which epoch produced them and which outer number was in effect
        when they were computed.  Two of the epochs are left early, once
        by ``break`` and once by an exception.
        """
        epoch_no = [0]
        outer_no = [1]

        def shift(x):
            return (x + epoch_no[0] * 10 + outer_no[0] * 100,)

        def next_epoch_flow():
            epoch_no[0] += 1
            return shifted_flow

        def pause_then_bump():
            # Sleep first, so that (presumably) the background worker has
            # prefetched with the old outer number before it is changed.
            time.sleep(.1)
            outer_no[0] += 1

        def first_arrays(flow):
            return [batch[0] for batch in flow]

        def check_whole_epoch(flow, low, high):
            # Interruption is not well-predictable under multi-threading,
            # so only a weak check is made: every batch must carry one of
            # two acceptable offsets.
            for idx, (arr,) in enumerate(flow):
                lo, hi = low + idx * 2, high + idx * 2
                self.assertIn(arr[0], (lo, hi))
                self.assertIn(arr[1], (lo + 1, hi + 1))

        shifted_flow = DataFlow.seq(0, 10, batch_size=2).map(shift)
        factory_flow = DataFlow.iterator_factory(next_epoch_flow)

        with factory_flow.threaded(prefetch=2) as flow:
            # epoch 1 with outer number 1: expect 110 .. 119
            np.testing.assert_array_equal(
                [[110, 111], [112, 113], [114, 115], [116, 117], [118, 119]],
                first_arrays(flow)
            )
            pause_then_bump()

            # epoch 2: the first three batches still carry the old outer
            # number.  With ``prefetch = 2`` three items are affected,
            # because besides the two queued items one more prefetched item
            # is waiting to be enqueued.
            np.testing.assert_array_equal(
                [[120, 121], [122, 123], [124, 125], [226, 227], [228, 229]],
                first_arrays(flow)
            )
            pause_then_bump()

            # epoch 3: leave the epoch early by ``break``
            for arr in flow:
                np.testing.assert_array_equal([230, 231], arr[0])
                break
            pause_then_bump()

            # the epoch number must have increased after the break
            check_whole_epoch(flow, 340, 440)
            pause_then_bump()

            # epoch 5: leave the epoch early by an error
            try:
                for arr in flow:
                    np.testing.assert_array_equal([450, 451], arr[0])
                    raise _MyError()
            except _MyError:
                pass
            pause_then_bump()

            # the epoch number must have increased after the error
            check_whole_epoch(flow, 560, 660)