def test_auto_init(self):
    """Check that a threaded flow can be iterated again after ``close()``.

    The source flow maps ``DataFlow.seq(0, 10, batch_size=2)`` by adding
    ``epoch_counter * 10``, where ``epoch_counter`` is incremented every time
    the iterator factory is invoked.  The offset of each batch therefore
    reveals which invocation of the factory produced it.
    """
    epoch_counter = [0]
    seq_flow = DataFlow.seq(0, 10, batch_size=2)
    map_flow = seq_flow.map(lambda x: (x + epoch_counter[0] * 10, ))

    def make_iterator():
        epoch_counter[0] += 1
        return map_flow

    it_flow = DataFlow.iterator_factory(make_iterator)
    flow = it_flow.threaded(3)
    # Always close the flow, even when an assertion fails, so that whatever
    # worker `threaded()` started (presumably a background thread) is not
    # left running after the test.  On success the sequence of `close()`
    # calls is exactly the same as before.
    try:
        batches = [b[0] for b in flow]
        np.testing.assert_array_equal(
            [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], batches)

        batches = [b[0] for b in flow]
        np.testing.assert_array_equal(
            [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]], batches)

        # iterating after `close()` must transparently re-initialize the flow
        flow.close()
        batches = [b[0] for b in flow]
        # NOTE(review): the offset jumps from 20 to 40, i.e. the factory was
        # invoked once more than the visible iterations account for --
        # presumably the worker began another epoch before `close()`.
        # Confirm against the threaded flow implementation.
        np.testing.assert_array_equal(
            [[40, 41], [42, 43], [44, 45], [46, 47], [48, 49]], batches)
    finally:
        flow.close()
def test_iterator_factory(self):
    """Check that ``DataFlow.iterator_factory`` yields the factory's items.

    The factory zips two array flows over 5 items each with
    ``batch_size=3``.  The resulting flow is therefore expected to produce
    two ``(x, y)`` batches: a full batch of 3 and a trailing batch of 2.
    """
    x_flow = DataFlow.arrays([np.arange(5)], batch_size=3)
    y_flow = DataFlow.arrays([np.arange(5, 10)], batch_size=3)
    flow = DataFlow.iterator_factory(
        lambda: ((x, y) for (x, ), (y, ) in zip(x_flow, y_flow)))

    b = list(flow)
    # use `assertEqual`: the `assertEquals` alias is deprecated and was
    # removed in Python 3.12
    self.assertEqual(2, len(b))
    self.assertEqual(2, len(b[0]))
    np.testing.assert_array_equal([0, 1, 2], b[0][0])
    np.testing.assert_array_equal([5, 6, 7], b[0][1])
    np.testing.assert_array_equal([3, 4], b[1][0])
    np.testing.assert_array_equal([8, 9], b[1][1])
def iter_steps(self, data_generator=None):
    """
    Iterate through the steps of the active epoch.

    This method may only be called while an epoch loop is active and no
    other step loop is currently being iterated.

    Args:
        data_generator: Optional iterable whose items are yielded together
            with the step counter.  It is required when `max_step` is not
            configured, so that the step loop cannot run forever.  A
            :class:`DataFlow` is consumed directly; any other iterable is
            wrapped so that it is consumed at most once.

    Yields:
        int or (int, any): The global step counter (starting from 1), or
            the tuple ``(step counter, batch data)`` if `data_generator`
            is specified.

    Raises:
        RuntimeError: If no epoch loop is active, if another step loop is
            already open, or if neither `max_step` nor `data_generator` is
            given.  Since this is a generator, these errors surface on the
            first iteration rather than at call time.
    """
    def more_steps_allowed():
        return self._max_step is None or self._step < self._max_step

    def wrap_once(source):
        # Wrap a plain iterable into a DataFlow; after the iterable has been
        # fully consumed once, the holder is cleared so later iterations of
        # the flow yield nothing.
        holder = [source]

        def one_shot_iterator():
            if holder[0] is not None:
                for item in holder[0]:
                    yield item
            holder[0] = None

        return DataFlow.iterator_factory(one_shot_iterator)

    # --- preconditions ---
    self._require_entered()
    if not self._within_epoch:
        raise RuntimeError('Step loop must be opened within active epoch '
                           'loop')
    if self._within_step:
        raise RuntimeError('Another step loop has been opened')
    if self._max_step is None and data_generator is None:
        raise RuntimeError('`data_generator` is required when `max_step` '
                           'is not configured, so as to prevent an '
                           'unstoppable step loop')

    try:
        if data_generator is not None:
            self._data_flow = (
                data_generator if isinstance(data_generator, DataFlow)
                else wrap_once(data_generator)
            )

        while more_steps_allowed():
            # fetch the payload for the upcoming step
            if self._data_flow is not None:
                try:
                    batch = self._data_flow.next_batch()
                except StopIteration:
                    break  # the data source is exhausted
                payload = self._step + 1, batch
            else:
                payload = self._step + 1

            # open the step and hand control to the caller
            self._step += 1
            self._within_step = True
            self._step_start_time = time.time()
            try:
                yield payload
            except StopIteration:  # pragma: no cover
                # might be caused by call to ``data_flow.next_batch()``
                break
            self._commit_step_stop_time()
    finally:
        # reset the step-loop state however the loop ends (exhaustion,
        # max_step reached, caller break/close, or an exception)
        self._within_step = False
        self._step_start_time = None
        self._data_flow = None
def test_iterator(self):
    """Exercise a prefetching threaded flow across complete and interrupted epochs.

    Each batch is offset by ``epoch_counter * 10 + external_counter * 100``.
    ``epoch_counter`` is incremented whenever the iterator factory is invoked.
    ``external_counter`` is bumped by the test between epochs.  The offsets
    therefore reveal when each batch was computed relative to those bumps.
    """
    epoch_counter = [0]
    external_counter = [1]
    seq_flow = DataFlow.seq(0, 10, batch_size=2)
    map_flow = seq_flow.map(lambda x: (x + epoch_counter[0] * 10 +
                                       external_counter[0] * 100, ))

    def make_iterator():
        epoch_counter[0] += 1
        return map_flow

    it_flow = DataFlow.iterator_factory(make_iterator)
    with it_flow.threaded(prefetch=2) as flow:
        # the first epoch (epoch counter 1, external counter 1):
        # expect 110 .. 119
        np.testing.assert_array_equal(
            [[110, 111], [112, 113], [114, 115], [116, 117], [118, 119]],
            [a[0] for a in flow])
        time.sleep(.1)
        external_counter[0] += 1

        # the second epoch (epoch counter 2): the batches already computed
        # before the external counter was bumped keep the 100 offset, the
        # remaining ones get the new 200 offset
        np.testing.assert_array_equal(
            # having `prefetch = 2` should affect 3 items, because
            # while the queue size is 2, there are 1 additional prefetched
            # item waiting to be enqueued
            [[120, 121], [122, 123], [124, 125], [226, 227], [228, 229]],
            [a[0] for a in flow])
        time.sleep(.1)
        external_counter[0] += 1

        # the third epoch, we shall carry out an incomplete epoch by break
        for a in flow:
            np.testing.assert_array_equal([230, 231], a[0])
            break
        time.sleep(.1)
        external_counter[0] += 1

        # the fourth epoch: verify that the epoch counter increases after break
        for i, (a, ) in enumerate(flow):
            # because the interruption is not well-predictable under
            # multi-threading context, we shall have a weaker verification
            # than the above
            self.assertTrue((340 + i * 2 == a[0]) or (440 + i * 2 == a[0]))
            self.assertTrue((341 + i * 2 == a[1]) or (441 + i * 2 == a[1]))
        time.sleep(.1)
        external_counter[0] += 1

        # carry out the fifth, incomplete epoch by error
        try:
            for a in flow:
                np.testing.assert_array_equal([450, 451], a[0])
                raise _MyError()
        except _MyError:
            pass
        time.sleep(.1)
        external_counter[0] += 1

        # the sixth epoch: verify that the epoch counter increases after error
        for i, (a, ) in enumerate(flow):
            self.assertTrue((560 + i * 2 == a[0]) or (660 + i * 2 == a[0]))
            self.assertTrue((561 + i * 2 == a[1]) or (661 + i * 2 == a[1]))