def test_queue_full_assertion():
    """Check that overfilling a ShmQueue raises ``RuntimeError("The queue is full")``.

    Yields one nose-style check per combination of multiprocessing start method,
    queue capacity and put mode (single messages vs. the whole batch at once).
    Each check receives a fresh queue of capacity ``cap`` and ``cap + 1``
    messages, i.e. exactly one message more than the queue can hold.
    """
    for method in ("spawn", "fork"):
        # get_context is looked up once per start method; each queue below is
        # still created against that method's context.
        ctx = multiprocessing.get_context(method)
        for cap in [1, 4]:
            for put_one_by_one in (True, False):
                fresh_queue = ShmQueue(ctx, cap)
                overflowing_msgs = [ShmMessageDesc(k, k, k, k, k) for k in range(cap + 1)]
                expects_full = raises(RuntimeError, "The queue is full")(_put_msgs)
                yield expects_full, fresh_queue, overflowing_msgs, put_one_by_one
def from_contexts(cls, contexts: List[CallbackContext], num_workers, start_method="fork",
                  py_callback_pickler=None):
    """Create the shared-memory queues and per-worker contexts, then build an instance of ``cls``.

    Args:
        contexts: callback contexts, each exposing ``shm_manager`` and ``dedicated_worker_id``.
        num_workers: number of worker processes to prepare contexts for.
        start_method: multiprocessing start method passed to ``multiprocessing.get_context``.
        py_callback_pickler: forwarded to ``pickling._CustomPickler.create``; ignored
            (no pickler is created) when ``start_method == "fork"``.

    Returns:
        The constructed ``cls`` instance, after ``close_handle()`` has been called on every
        queue created here.

    If constructing the instance or closing the handles raises (including
    ``BaseException`` such as ``KeyboardInterrupt``), an already-created instance is
    ``close()``-d and the exception propagates.
    """
    mp = multiprocessing.get_context(start_method)

    # Sources not pinned to any worker share a single task queue that every worker can read.
    unpinned_managers = [
        cb_ctx.shm_manager for cb_ctx in contexts if cb_ctx.dedicated_worker_id is None
    ]
    general_task_queue = None
    if unpinned_managers:
        # A scheduled task always owns one result shm chunk, so the number of in-flight
        # tasks is bounded by the total chunk count of these sources.
        general_task_queue = ShmQueue(
            mp, capacity=sum(manager.num_chunks for manager in unpinned_managers))

    # One result message per computed minibatch, so the chunk total across all sources
    # bounds the message count. The max() with num_workers leaves room for the
    # per-worker initialization confirmations.
    all_chunks = sum(cb_ctx.shm_manager.num_chunks for cb_ctx in contexts)
    result_queue = ShmQueue(mp, capacity=max(all_chunks, num_workers))

    if start_method == "fork":
        callback_pickler = None
    else:
        callback_pickler = pickling._CustomPickler.create(py_callback_pickler)

    worker_contexts = create_worker_contexts(mp, contexts, num_workers, callback_pickler)

    instance = None
    try:
        instance = cls(mp, worker_contexts, result_queue, general_task_queue, callback_pickler)
        # Close the handles in a fixed order: general task queue, result queue, then each
        # worker's dedicated queue. Absent (None) queues are skipped.
        owned_queues = [general_task_queue, result_queue]
        owned_queues.extend(w_ctx.dedicated_task_queue for w_ctx in worker_contexts)
        for shm_queue in owned_queues:
            if shm_queue is not None:
                shm_queue.close_handle()
    except BaseException:
        if instance is not None:
            instance.close()
        raise
    return instance
def create_worker_contexts(mp, callback_contexts: List[CallbackContext], num_workers,
                           callback_pickler) -> List[WorkerContext]:
    """Build one ``WorkerContext`` per worker process.

    Each worker receives every source without a dedicated worker, plus the sources
    whose ``dedicated_worker_id`` equals its own index. Its context holds:

    * a dict mapping the source's index in ``callback_contexts`` to its source description,
    * the shm chunks of those sources (from ``shm_manager.get_chunks()``), and
    * a dedicated ``ShmQueue``, or ``None`` when no source is pinned to that worker.

    When ``callback_pickler`` is given, each source description is shallow-copied and
    its ``source`` replaced with ``callback_pickler.dumps(source)``; the originals are
    left untouched. When it is ``None``, the original description objects are shared.
    """
    if callback_pickler is None:
        source_descs = [cb_context.source_desc for cb_context in callback_contexts]
    else:
        source_descs = []
        for cb_context in callback_contexts:
            # Shallow copy, so replacing `.source` does not mutate the caller's description.
            serialized_desc = copy.copy(cb_context.source_desc)
            serialized_desc.source = callback_pickler.dumps(serialized_desc.source)
            source_descs.append(serialized_desc)

    shared_idxs = [
        idx for idx, cb_context in enumerate(callback_contexts)
        if cb_context.dedicated_worker_id is None
    ]

    def build_worker_context(worker_id):
        # Indices of the sources pinned exclusively to this worker.
        pinned_idxs = [
            idx for idx, cb_context in enumerate(callback_contexts)
            if cb_context.dedicated_worker_id == worker_id
        ]
        # Shared sources come first, then the pinned ones.
        visible_idxs = shared_idxs + pinned_idxs
        sources = {idx: source_descs[idx] for idx in visible_idxs}
        shm_chunks = [
            chunk
            for idx in visible_idxs
            for chunk in callback_contexts[idx].shm_manager.get_chunks()
        ]
        own_queue = None
        if pinned_idxs:
            # A scheduled task always owns one result shm chunk, so the pinned sources'
            # total chunk count bounds the number of tasks waiting in this queue.
            own_queue = ShmQueue(
                mp,
                capacity=sum(callback_contexts[idx].shm_manager.num_chunks for idx in pinned_idxs))
        return WorkerContext(sources, own_queue, shm_chunks)

    return [build_worker_context(worker_id) for worker_id in range(num_workers)]