def check_callback(parallel_pipe, pipe, epoch_size, batch_size, dtype=None):
    """Check that a parallel-external-source pipeline produces the same data
    as its serial reference pipeline.

    Args:
        parallel_pipe: Pipeline using a parallel (worker-pool backed) source.
        pipe: Reference pipeline computing the same data serially.
        epoch_size: Total number of samples per epoch.
        batch_size: Samples per batch; iterations compared = epoch_size // batch_size.
        dtype: Unused here; kept so callers can pass it uniformly.  # TODO confirm
    """
    iters_no = epoch_size // batch_size
    parallel_pipe.build()
    pipe.build()
    # Track worker processes so they can be cleaned up if the test fails.
    capture_processes(parallel_pipe._py_pool)
    compare_pipelines(parallel_pipe, pipe, batch_size, iters_no)
    parallel_pipe._py_pool.close()
def build_and_run_pipeline(pipe, iters=None, *args):
    """Build the pipeline and run it ``iters`` times.

    Args:
        pipe: Pipeline to build and run.
        iters: Number of iterations to run; when None, runs until ``pipe.run()``
            raises (e.g. StopIteration from an exhausted source).
        *args: Ignored; kept for compatibility with parameterized callers.
    """
    pipe.build()
    # Track worker processes so they can be cleaned up if the test fails.
    capture_processes(pipe._py_pool)
    if iters is None:
        while True:
            pipe.run()
    else:
        for _ in range(iters):
            pipe.run()
def check_layout(pipe, layout):
    """Run ``pipe`` to exhaustion, asserting every output batch has ``layout``.

    Args:
        pipe: Single-output pipeline; iterated until it raises StopIteration.
        layout: Expected layout string of each output batch.
    """
    pipe.build()
    # Track worker processes so they can be cleaned up if the test fails.
    capture_processes(pipe._py_pool)
    while True:
        try:
            (res,) = pipe.run()
            assert res.layout() == layout
        except StopIteration:
            # Epoch finished — all produced batches matched the layout.
            break
def create_pool(groups, keep_alive_queue_size=1, num_workers=1, start_method="fork"):
    """Create a WorkerPool from callback groups, wrapped for use as a context manager.

    Args:
        groups: Callback groups passed to ``WorkerPool.from_groups``.
        keep_alive_queue_size: Size of the pool's keep-alive results queue.
        num_workers: Number of worker processes.
        start_method: Multiprocessing start method ("fork" or "spawn").

    Returns:
        ``contextlib.closing`` wrapper around the pool, so ``with`` closes it.

    Raises:
        Re-raises any error from process capture, after closing the pool to
        avoid leaking worker processes.
    """
    pool = WorkerPool.from_groups(
        groups, keep_alive_queue_size, start_method=start_method, num_workers=num_workers)
    try:
        # Track worker processes so they can be cleaned up if the test fails.
        capture_processes(pool)
        return closing(pool)
    except Exception:
        # Don't leak worker processes if setup fails mid-way.
        pool.close()
        raise
def check_stop_iteration_resume(pipe, batch_size, layout):
    """Check that a pipeline reset after StopIteration replays the epoch identically.

    Runs two full epochs (resetting in between) and asserts both epochs have
    the same number of iterations and element-wise equal batches.

    Args:
        pipe: Single-output pipeline raising StopIteration at epoch end.
        batch_size: Expected number of samples per batch.
        layout: Expected layout of every batch.
    """
    pipe.build()
    # Track worker processes so they can be cleaned up if the test fails.
    capture_processes(pipe._py_pool)
    outputs_epoch_1, outputs_epoch_2 = [], []
    for output in [outputs_epoch_1, outputs_epoch_2]:
        try:
            while True:
                (r,) = pipe.run()
                # Copy out of DALI-owned memory before the next run() reuses it.
                r = [np.copy(r.at(i)) for i in range(len(r))]
                output.append(r)
        except StopIteration:
            pipe.reset()
    assert len(outputs_epoch_1) == len(outputs_epoch_2), (
        "Epochs must have same number of iterations, "
        "but they have {} {} respectively".format(len(outputs_epoch_1), len(outputs_epoch_2)))
    for out_1, out_2 in zip(outputs_epoch_1, outputs_epoch_2):
        check_batch(out_1, out_2, batch_size, 0, None,
                    expected_layout=layout, compare_layouts=True)
def test_pool_context_sync(start_method):
    """Check that ``WorkerPool.reset`` discards all previously scheduled batches.

    Schedules four batches per context, resets the pool, schedules one fresh
    batch per context, and verifies only the post-reset batch is pending and
    its results match the callbacks' expected answers.

    Args:
        start_method: Multiprocessing start method ("fork" or "spawn").
    """
    callbacks = [simple_callback, another_callback]
    groups = [MockGroup.from_callback(cb, prefetch_queue_depth=3) for cb in callbacks]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=4,
                     start_method=start_method) as pool:
        capture_processes(pool)
        # Fill both contexts with work that reset() must throw away.
        # (The original also built a throwaway `tasks` list here; it was never read.)
        for i in range(4):
            work_batch = TaskArgs.make_sample(SampleRange(0, 10 * (i + 1), 0, 0))
            pool.schedule_batch(context_i=0, work_batch=work_batch)
            pool.schedule_batch(context_i=1, work_batch=work_batch)
        assert_scheduled_num(pool.contexts[0], 4)
        assert_scheduled_num(pool.contexts[1], 4)
        # pool after a reset should discard all previously scheduled tasks
        # (and sync workers to avoid race on writing to results buffer)
        pool.reset()
        tasks = [(SampleInfo(1000 + j, j, 0, 1),) for j in range(5)]
        work_batch = TaskArgs.make_sample(SampleRange(1000, 1005, 0, 1))
        pool.schedule_batch(context_i=0, work_batch=work_batch)
        pool.schedule_batch(context_i=1, work_batch=work_batch)
        # Only the single post-reset batch may remain scheduled per context.
        assert_scheduled_num(pool.contexts[0], 1)
        assert_scheduled_num(pool.contexts[1], 1)
        batch_0 = pool.receive_batch(context_i=0)
        batch_1 = pool.receive_batch(context_i=1)
        assert len(batch_0) == len(tasks)
        assert len(batch_1) == len(tasks)
        # simple_callback's context produces answer(...) directly ...
        for task, sample in zip(tasks, batch_0):
            np.testing.assert_array_equal(answer(-1, *task)[1:], sample[1:])
        # ... while another_callback's context is offset by 100 — presumably
        # mirroring the callback's definition; verify against its source.
        for task, sample in zip(tasks, batch_1):
            np.testing.assert_array_equal(answer(-1, *task)[1:] + 100, sample[1:])