Example #1
File: loader.py Project: yifeim/gluon-ts
import io
import multiprocessing as mp
import pickle
from multiprocessing.reduction import ForkingPickler
from typing import Callable

# `batcher` and `MPWorkerInfo` are defined elsewhere in this module.


def worker_fn(
    worker_id: int,
    num_workers: int,
    dataset,
    batch_size: int,
    stack_fn: Callable,
    batch_queue: mp.Queue,
    terminate_event,
    exhausted_event,
):
    # Record this process's identity so dataset sharding can key off it.
    MPWorkerInfo.set_worker_info(
        num_workers=num_workers,
        worker_id=worker_id,
    )

    for batch in batcher(dataset, batch_size):
        stacked_batch = stack_fn(batch)
        try:
            # Stop promptly if the parent has asked workers to shut down.
            if terminate_event.is_set():
                return
            # ForkingPickler applies multiprocessing's registered reducers
            # while serializing the batch for the queue.
            buf = io.BytesIO()
            ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(
                (worker_id, stacked_batch)
            )
            batch_queue.put(buf.getvalue())
        except (EOFError, BrokenPipeError):
            # The parent closed the queue; exit quietly.
            return

    # Signal the consumer that this worker has run out of data.
    exhausted_event.set()
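A minimal driver sketch (not part of gluon-ts; `toy_dataset` and `toy_stack_fn` are made-up placeholders, and worker_fn plus its helpers are assumed importable) showing how a parent process might spawn this worker and drain the queue:

import multiprocessing as mp
import pickle
import queue

def toy_stack_fn(batch):
    # Placeholder: a real stack_fn would collate the batch into arrays.
    return batch

if __name__ == "__main__":
    num_workers = 2
    toy_dataset = list(range(100))  # placeholder dataset
    batch_queue: mp.Queue = mp.Queue()
    terminate_event = mp.Event()
    exhausted_events = [mp.Event() for _ in range(num_workers)]

    workers = [
        mp.Process(
            target=worker_fn,
            args=(wid, num_workers, toy_dataset, 8, toy_stack_fn,
                  batch_queue, terminate_event, exhausted_events[wid]),
        )
        for wid in range(num_workers)
    ]
    for w in workers:
        w.start()

    # Drain until every worker has signalled exhaustion and the queue is empty.
    while True:
        try:
            raw = batch_queue.get(timeout=0.1)
        except queue.Empty:
            if all(e.is_set() for e in exhausted_events):
                break
            continue
        worker_id, stacked_batch = pickle.loads(raw)
        print(worker_id, stacked_batch)

    for w in workers:
        w.join()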
Example #2
import multiprocessing as mp

# `_encode` and `MPWorkerInfo` are defined elsewhere in this module.


def worker_fn(
    worker_id: int,
    dataset,
    num_workers: int,
    result_queue: mp.Queue,
):
    # Record this process's identity so dataset sharding can key off it.
    MPWorkerInfo.set_worker_info(
        num_workers=num_workers,
        worker_id=worker_id,
    )

    # Encode each entry and stream it to the parent process.
    for raw in map(_encode, dataset):
        try:
            result_queue.put(raw)
        except (EOFError, BrokenPipeError):
            # The parent closed the queue; exit quietly.
            return
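Note how, as in Example #1, the worker swallows EOFError and BrokenPipeError rather than crashing: these exceptions surface when the parent process tears down the queue first, so a silent return is the graceful way for the worker to follow it down.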
Example #3
import multiprocessing
from multiprocessing import Queue

# `_WorkerData`, `MPWorkerInfo`, `Dataset`, and `Transformation` are defined
# elsewhere in this module / package.


def _worker_initializer(
    dataset: Dataset,
    transformation: Transformation,
    num_workers: int,
    worker_id_queue: Queue,
) -> None:
    """Initializer for processing pool."""

    # Stash the dataset and transformation in module-level state so pool
    # tasks can reach them without re-pickling them per task.
    _WorkerData.dataset = dataset
    _WorkerData.transformation = transformation

    # get unique worker id
    worker_id = int(worker_id_queue.get())
    multiprocessing.current_process().name = f"worker_{worker_id}"

    # propagate worker information
    MPWorkerInfo.set_worker_info(
        num_workers=num_workers,
        worker_id=worker_id,
        worker_process=True,
    )
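A sketch (not part of gluon-ts; `my_dataset` and `my_transformation` are placeholders, and _worker_initializer is assumed importable) of how this initializer could be wired into a multiprocessing.Pool, pre-seeding the queue with one id per worker so each process draws a unique worker_id:

import multiprocessing
from multiprocessing import Queue

if __name__ == "__main__":
    num_workers = 4

    # One id per worker; each pool process pops exactly one in its initializer.
    worker_id_queue: Queue = Queue()
    for i in range(num_workers):
        worker_id_queue.put(i)

    pool = multiprocessing.Pool(
        num_workers,
        initializer=_worker_initializer,
        initargs=(my_dataset, my_transformation, num_workers, worker_id_queue),
    )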
Example #4
import multiprocessing as mp

# `_encode` and `MPWorkerInfo` are defined elsewhere in this module.


def worker_fn(
    worker_id: int,
    dataset,
    num_workers: int,
    input_queue: mp.Queue,
    output_queue: mp.Queue,
):
    # Record this process's identity so dataset sharding can key off it.
    MPWorkerInfo.set_worker_info(
        num_workers=num_workers,
        worker_id=worker_id,
    )

    while True:
        try:
            # Block until the parent requests another pass over the dataset.
            input_queue.get()
            for encoded_entry in map(_encode, dataset):
                output_queue.put(encoded_entry)
            # An encoded `None` marks the end of this pass.
            output_queue.put(_encode(None))
        except (EOFError, BrokenPipeError):
            # The parent closed a queue; exit quietly.
            return
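Unlike the one-shot workers in Examples #1 and #2, this variant is re-entrant: it blocks on input_queue.get() until the parent requests a pass, streams one full encoded traversal of the dataset, and then emits _encode(None) as an end-of-pass sentinel so the consumer knows when the pass is complete and another one may be requested.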