def _test_epoch_idx(batch_size, epoch_size, cb, batch_info, batch_mode,
                    num_epochs=3):
    """Run an external_source pipeline for several epochs and check indices.

    A single ``fn.external_source`` fed by ``cb`` is run for ``num_epochs``
    epochs of ``epoch_size`` iterations each. Every sample is compared with
    the values the callback is expected to produce:

    * batch mode: ``[iteration, epoch_idx]`` when ``batch_info`` is set,
      otherwise ``[iteration, -1]``;
    * per-sample mode:
      ``[iteration * batch_size + sample_i, sample_i, iteration, epoch_idx]``.

    After each epoch one extra ``pipe.run()`` must raise ``StopIteration``.
    The pipeline is then reset before the next epoch.

    Args:
        batch_size: Number of samples expected in every batch.
        epoch_size: Number of iterations expected before the source is
            exhausted.
        cb: Callback passed as ``source`` to ``fn.external_source``.
            Presumably it encodes the info object it receives into its
            output (see the expected values above) -- confirm against callers.
        batch_info: Forwarded to ``fn.external_source(batch_info=...)``.
        batch_mode: Forwarded to ``fn.external_source(batch=...)``.
        num_epochs: Number of epochs to run (defaults to 3, the previously
            hard-coded value).

    Raises:
        AssertionError: If a batch has the wrong size or content, or if the
            source is not exhausted after ``epoch_size`` iterations.
    """
    pipe = Pipeline(batch_size, 1, 0)
    with pipe:
        ext = fn.external_source(source=cb,
                                 batch_info=batch_info,
                                 batch=batch_mode)
        pipe.set_outputs(ext)
    pipe.build()
    for epoch_idx in range(num_epochs):
        for iteration in range(epoch_size):
            (batch, ) = pipe.run()
            assert len(batch) == batch_size, \
                f"expected {batch_size} samples, got {len(batch)}"
            for sample_i, sample in enumerate(batch):
                if batch_mode:
                    expected = np.array(
                        [iteration, epoch_idx if batch_info else -1])
                else:
                    expected = np.array([
                        iteration * batch_size + sample_i, sample_i, iteration,
                        epoch_idx
                    ])
                np.testing.assert_array_equal(sample, expected)
        # Only StopIteration means the source ended correctly. Any other
        # error must propagate instead of being mistaken for end-of-epoch.
        try:
            pipe.run()
        except StopIteration:
            pipe.reset()
        else:
            # `raise` rather than `assert False` so the check survives -O.
            raise AssertionError("expected StopIteration")
def test_external_source_collection_cycling_raise():
    """Check that ``cycle="raise"`` ends every epoch with StopIteration.

    The same two batches are fed through two external sources: one backed by
    a plain list and one by a generator function. The test expects both to
    replay the data after every ``pipe.reset()`` and to raise StopIteration
    once the data is exhausted.
    """
    pipe = Pipeline(1, 3, 0, prefetch_queue_depth=1)

    batches = [
        [make_array([1.5, 2.5], dtype=datapy.float32)],
        [make_array([-1, 3.5, 4.5], dtype=datapy.float32)],
    ]

    def batch_gen():
        yield from batches

    list_source = fn.external_source(batches, cycle="raise")
    gen_source = fn.external_source(batch_gen, cycle="raise")
    pipe.set_outputs(list_source, gen_source)
    pipe.build()

    # An epoch is one full pass over `batches`. The run after the last batch
    # must raise, and reset() starts the next pass.
    num_epochs = 3
    for _ in range(num_epochs):
        for source_batch in batches:
            outputs = pipe.run()
            host_batch = asnumpy(source_batch)
            check_output(outputs, (host_batch, host_batch))

        with assert_raises(StopIteration):
            pipe.run()
        pipe.reset()