def test_pool_context_sync(start_method):
    """After ``pool.reset()`` all previously scheduled batches are discarded.

    Four batches are queued per context, then the pool is reset; only the
    single batch scheduled afterwards should be pending and received.
    Fix: the first scheduling loop rebuilt a ``tasks`` list on every
    iteration without ever using it (it was overwritten after the reset) —
    that dead code is removed.
    """
    callbacks = [simple_callback, another_callback]
    groups = [
        MockGroup.from_callback(cb, prefetch_queue_depth=3)
        for cb in callbacks
    ]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=4,
                     start_method=start_method) as pool:
        capture_processes(pool)
        # Fill both contexts with work that the reset below must discard.
        for i in range(4):
            work_batch = TaskArgs.make_sample(
                SampleRange(0, 10 * (i + 1), 0, 0))
            pool.schedule_batch(context_i=0, work_batch=work_batch)
            pool.schedule_batch(context_i=1, work_batch=work_batch)
        assert_scheduled_num(pool.contexts[0], 4)
        assert_scheduled_num(pool.contexts[1], 4)
        # pool after a reset should discard all previously scheduled tasks
        # (and sync workers to avoid race on writing to results buffer)
        pool.reset()
        tasks = [(SampleInfo(1000 + j, j, 0, 1), ) for j in range(5)]
        work_batch = TaskArgs.make_sample(SampleRange(1000, 1005, 0, 1))
        pool.schedule_batch(context_i=0, work_batch=work_batch)
        pool.schedule_batch(context_i=1, work_batch=work_batch)
        assert_scheduled_num(pool.contexts[0], 1)
        assert_scheduled_num(pool.contexts[1], 1)
        batch_0 = pool.receive_batch(context_i=0)
        batch_1 = pool.receive_batch(context_i=1)
        assert len(batch_0) == len(tasks)
        assert len(batch_1) == len(tasks)
        # Work may be split across workers, so the pid slot (index 0) is
        # ignored; `another_callback` results are offset by 100.
        for task, sample in zip(tasks, batch_0):
            np.testing.assert_array_equal(answer(-1, *task)[1:], sample[1:])
        for task, sample in zip(tasks, batch_1):
            np.testing.assert_array_equal(
                answer(-1, *task)[1:] + 100, sample[1:])
def _test_multiple_stateful_sources_single_worker(num_workers):
    """Two iterator (batch) sources each report a dedicated worker id.

    With a single worker both contexts must share the same process; with
    more workers each context gets its own.
    """
    groups = [
        MockGroup.from_callback(IteratorCb(), batch=True),
        MockGroup.from_callback(IteratorCb(), batch=True),
    ]
    with create_pool(groups, keep_alive_queue_size=1,
                     num_workers=num_workers, start_method="spawn") as pool:
        worker_pids = get_pids(pool)
        assert len(worker_pids) == min(num_workers, len(groups))
        for ctx in (0, 1):
            pool.schedule_batch(context_i=ctx,
                                work_batch=TaskArgs.make_batch((0, )))
        # Resolve each context's dedicated worker to a concrete process pid.
        dedicated_pids = [
            pool.pool._processes[pool.contexts[ctx].dedicated_worker_id].pid
            for ctx in (0, 1)
        ]
        received = [pool.receive_batch(context_i=ctx) for ctx in (0, 1)]
        # Each returned sample carries [producer pid, iteration counter].
        for expected_pid, batch in zip(dedicated_pids, received):
            np.testing.assert_array_equal(
                np.array([expected_pid, 1]), batch[0])
        if num_workers == 1:
            assert dedicated_pids[0] == dedicated_pids[1]
        else:
            assert dedicated_pids[0] != dedicated_pids[1]
def schedule_batch(self, pool, context_i, lead, batch_size, epoch_idx):
    """Schedule computing new batch from source callback by the parallel pool.

    In batch mode the callback arguments are forwarded as a single batch
    task; in sample mode a contiguous range of sample indices (shifted by
    ``lead`` batches ahead of the current position) is described instead.
    """
    if self.batch:
        args = self.callback_args(None, epoch_idx, lead=lead)
        return pool.schedule_batch(context_i, _TaskArgs.make_batch(args))
    # Sample mode: the batch is the half-open index range
    # [current + lead * batch_size, current + (lead + 1) * batch_size).
    range_start = self.current_sample + batch_size * lead
    sample_range = _SampleRange(range_start, range_start + batch_size,
                                self.current_iter + lead, epoch_idx)
    return pool.schedule_batch(context_i, _TaskArgs.make_sample(sample_range))
def test_pool_no_overwrite_batch(start_method):
    """With keep-alive depth N, N batches scheduled up front must all be
    retrievable unclobbered — earlier results are not overwritten before
    they are received.

    Fix: dropped the unused ``enumerate`` index in the scheduling loop and
    the needless ``tasks_batches`` intermediate.
    """
    groups = [MockGroup.from_callback(simple_callback,
                                      prefetch_queue_depth=0)]
    for depth in [1, 2, 4, 8]:
        with create_pool(groups, keep_alive_queue_size=depth, num_workers=1,
                         start_method=start_method) as pool:
            pid = get_pids(pool)[0]
            work_batches = [
                TaskArgs.make_sample(SampleRange(i, i + 1, i, 0))
                for i in range(depth)
            ]
            task_list = [[(SampleInfo(i, 0, i, 0), )] for i in range(depth)]
            for work_batch in work_batches:
                pool.schedule_batch(context_i=0, work_batch=work_batch)
            assert_scheduled_num(pool.contexts[0], depth)
            # Drain one batch at a time; the pending count must shrink by 1.
            batches = []
            for i in range(depth):
                batches.append(pool.receive_batch(context_i=0))
                assert_scheduled_num(pool.contexts[0], depth - 1 - i)
            for tasks, batch in zip(task_list, batches):
                for task, sample in zip(tasks, batch):
                    np.testing.assert_array_equal(answer(pid, *task), sample)
def _split_work(self, work_batch: TaskArgs):
    """Split a sample-mode work batch into per-worker minibatches.

    Batch-mode work is indivisible and returned as a one-element list.
    Samples are spread as evenly as possible: the first ``remainder``
    workers receive one extra sample, and workers whose share would be
    empty are skipped entirely.
    """
    if not work_batch.is_sample_mode():
        return [work_batch]
    workers = self.pool.num_workers
    full_range = work_batch.sample_range
    base, extra = divmod(len(full_range), workers)
    minibatches = []
    offset = 0
    for idx in range(workers):
        share = base + (1 if idx < extra else 0)
        if share == 0:
            # All remaining workers would also get 0 samples.
            break
        chunk = full_range[offset:offset + share]
        minibatches.append(TaskArgs(idx, sample_range=chunk))
        offset += share
    return minibatches
def test_pool_invalid_return():
    """Exercise receiving a batch from a callback with an invalid return
    value (the error is presumably asserted by a decorator outside this
    view — confirm against the full file)."""
    groups = [MockGroup.from_callback(invalid_callback)]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=1,
                     start_method="spawn") as pool:
        _ = get_pids(pool)
        pool.schedule_batch(
            context_i=0,
            work_batch=TaskArgs.make_sample(SampleRange(0, 1, 0, 0)))
        pool.receive_batch(context_i=0)
def test_pool_iterator_dedicated_worker(start_method):
    """Mix a per-sample context with an iterator (batch) context.

    The per-sample context has no dedicated worker, while the iterator
    context is pinned to one worker whose pid appears in every sample it
    produces.
    """
    groups = [
        MockGroup.from_callback(simple_callback, prefetch_queue_depth=3),
        MockGroup.from_callback(IteratorCb(), prefetch_queue_depth=3,
                                batch=True),
    ]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=4,
                     start_method=start_method) as pool:
        assert len(get_pids(pool)) == 4
        # Schedule 4 rounds of growing batches on both contexts.
        tasks_list = []
        total = 0
        for i in range(4):
            batch_tasks = [(SampleInfo(total + j, j, i, 0), )
                           for j in range(i + 1)]
            tasks_list.append(batch_tasks)
            work_batch = TaskArgs.make_sample(
                SampleRange(total, total + i + 1, i, 0))
            total += len(batch_tasks)
            pool.schedule_batch(context_i=0, work_batch=work_batch)
            pool.schedule_batch(context_i=1,
                                work_batch=TaskArgs.make_batch((i, )))
        assert pool.contexts[0].dedicated_worker_id is None
        dedicated_id = pool.contexts[1].dedicated_worker_id
        dedicated_pid = pool.pool._processes[dedicated_id].pid
        for i, expected_tasks in enumerate(tasks_list):
            batch_0 = pool.receive_batch(context_i=0)
            batch_1 = pool.receive_batch(context_i=1)
            assert len(batch_0) == len(expected_tasks)
            assert len(batch_1) == len(expected_tasks)
            # Context 0's work may be split across workers: skip the pid
            # slot (index 0) when comparing.
            for task, sample in zip(expected_tasks, batch_0):
                np.testing.assert_array_equal(
                    answer(-1, *task)[1:], sample[1:])
            # Context 1's samples carry [dedicated pid, iteration counter].
            for sample in batch_1:
                np.testing.assert_array_equal(
                    np.array([dedicated_pid, i + 1]), sample)
def test_pool_multi_task(start_method):
    """A single worker computes a whole 10-sample batch; every returned
    sample must match the expected answer stamped with that worker's pid."""
    groups = [MockGroup.from_callback(simple_callback)]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=1,
                     start_method=start_method) as pool:
        worker_pid = get_pids(pool)[0]
        expected_tasks = [(SampleInfo(i, i, 0, 0), ) for i in range(10)]
        pool.schedule_batch(
            context_i=0,
            work_batch=TaskArgs.make_sample(SampleRange(0, 10, 0, 0)))
        received = pool.receive_batch(context_i=0)
        for task, sample in zip(expected_tasks, received):
            np.testing.assert_array_equal(answer(worker_pid, *task), sample)
def test_pool_work_split_multiple_tasks(start_method):
    """A 16-sample batch on a 2-worker pool gets split between workers, so
    the producing pid varies per sample — compare from index 1 onward."""
    groups = [MockGroup.from_callback(simple_callback)]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=2,
                     start_method=start_method) as pool:
        num_tasks = 16
        assert len(get_pids(pool)) == 2
        expected_tasks = [(SampleInfo(i, i, 0, 0), )
                          for i in range(num_tasks)]
        pool.schedule_batch(
            context_i=0,
            work_batch=TaskArgs.make_sample(
                SampleRange(0, num_tasks, 0, 0)))
        received = pool.receive_batch(context_i=0)
        for task, sample in zip(expected_tasks, received):
            np.testing.assert_array_equal(answer(-1, *task)[1:], sample[1:])
def test_pool_many_ctxs(start_method):
    """Two contexts served by one worker: both batches come from the same
    process, and `another_callback` results are offset by 100.

    Fix: the comparison loops zipped over the one-element ``pids`` list and
    rebound ``pid``, shadowing the outer ``pid = pids[0]``; the single
    worker pid is now used directly.
    """
    callbacks = [simple_callback, another_callback]
    groups = [MockGroup.from_callback(cb) for cb in callbacks]
    with create_pool(groups, keep_alive_queue_size=1, num_workers=1,
                     start_method=start_method) as pool:
        pid = get_pids(pool)[0]
        tasks = [(SampleInfo(0, 0, 0, 0), )]
        work_batch = TaskArgs.make_sample(SampleRange(0, 1, 0, 0))
        pool.schedule_batch(context_i=0, work_batch=work_batch)
        pool.schedule_batch(context_i=1, work_batch=work_batch)
        batch_0 = pool.receive_batch(context_i=0)
        batch_1 = pool.receive_batch(context_i=1)
        for task, sample in zip(tasks, batch_0):
            np.testing.assert_array_equal(answer(pid, *task), sample)
        for task, sample in zip(tasks, batch_1):
            np.testing.assert_array_equal(answer(pid, *task) + 100, sample)