def _test_epoch_idx(batch_size, epoch_size, cb, batch_info, batch_mode):
    """Run an external-source pipeline for several epochs and check every sample.

    Builds a pipeline with a single ``fn.external_source`` fed by ``cb``, runs
    it for ``num_epochs`` epochs of ``epoch_size`` iterations each, and compares
    each sample against the expected array computed below. After each epoch,
    one more ``pipe.run()`` must raise ``StopIteration``; the pipeline is then
    reset so the next epoch can start.

    Args:
        batch_size: Number of samples per batch.
        epoch_size: Number of iterations expected in each epoch.
        cb: Source callback passed to ``fn.external_source``.
        batch_info: Forwarded as ``batch_info=``; in batch mode it also selects
            whether the second expected element is ``epoch_idx`` or ``-1``.
        batch_mode: Forwarded as ``batch=``; selects which expected layout
            is checked.

    Raises:
        AssertionError: If a sample mismatches or an epoch does not end with
            ``StopIteration``.
    """
    num_epochs = 3
    pipe = Pipeline(batch_size, 1, 0)
    with pipe:
        ext = fn.external_source(source=cb, batch_info=batch_info, batch=batch_mode)
        pipe.set_outputs(ext)
    pipe.build()
    for epoch_idx in range(num_epochs):
        for iteration in range(epoch_size):
            (batch,) = pipe.run()
            assert len(batch) == batch_size
            for sample_i, sample in enumerate(batch):
                if batch_mode:
                    expected = np.array(
                        [iteration, epoch_idx if batch_info else -1])
                else:
                    expected = np.array([
                        iteration * batch_size + sample_i, sample_i, iteration,
                        epoch_idx
                    ])
                np.testing.assert_array_equal(sample, expected)
        # Only StopIteration marks the end of an epoch. Any other error from
        # pipe.run() must propagate and fail the test instead of being
        # swallowed by a bare except.
        try:
            pipe.run()
        except StopIteration:
            pipe.reset()
        else:
            # Explicit raise: an `assert False` would be stripped under -O.
            raise AssertionError("expected StopIteration")
def test_external_source_collection_cycling_raise():
    """Check that collection and generator-function sources both end epochs with StopIteration.

    With ``cycle="raise"``, one pass over ``batches`` forms an epoch. The run
    after the last batch must raise ``StopIteration``, and ``pipe.reset()``
    starts the next pass. Three such epochs are checked.
    """
    pipe = Pipeline(1, 3, 0, prefetch_queue_depth=1)
    batches = [
        [make_array([1.5, 2.5], dtype=datapy.float32)],
        [make_array([-1, 3.5, 4.5], dtype=datapy.float32)],
    ]

    def batch_gen():
        yield from batches

    from_collection = fn.external_source(batches, cycle="raise")
    from_generator = fn.external_source(batch_gen, cycle="raise")
    pipe.set_outputs(from_collection, from_generator)
    pipe.build()

    # Epochs are cycles over the source iterable.
    for _ in range(3):
        for source_batch in batches:
            outputs = pipe.run()
            host_batch = asnumpy(source_batch)
            # Both sources should yield the same data.
            check_output(outputs, (host_batch, host_batch))
        with assert_raises(StopIteration):
            pipe.run()
        pipe.reset()