Example #1
    def worker_fn(
        worker_id: int,
        num_workers: int,
        dataset,
        batch_size: int,
        stack_fn: Callable,
        batch_queue: mp.Queue,
        terminate_event,
        exhausted_event,
    ):
        # Record this process' identity so downstream code can tell
        # workers apart (e.g. for sharding the dataset).
        MPWorkerInfo.set_worker_info(
            num_workers=num_workers,
            worker_id=worker_id,
        )

        for batch in batcher(dataset, batch_size):
            stacked_batch = stack_fn(batch)
            try:
                if terminate_event.is_set():
                    return
                # Serialize with ForkingPickler so that shared-memory
                # handles survive the trip through the queue.
                buf = io.BytesIO()
                ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(
                    (worker_id, stacked_batch)
                )
                batch_queue.put(buf.getvalue())
            except (EOFError, BrokenPipeError):
                # The consumer has gone away; exit quietly.
                return

        # Signal the consumer that this worker has run out of data.
        exhausted_event.set()
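For context, here is a minimal, self-contained sketch of the producer side of this pattern. Note that demo_worker and the inline batcher are stand-ins, not gluonts APIs, and a None sentinel replaces the exhausted_event for brevity:

import multiprocessing as mp
from itertools import islice

def batcher(iterable, batch_size):
    # Stand-in for the batcher helper: yield lists of up to
    # batch_size items from the iterable.
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

def demo_worker(dataset, batch_size, batch_queue):
    # Hypothetical simplified worker: push batches, then a sentinel.
    for batch in batcher(dataset, batch_size):
        batch_queue.put(batch)
    batch_queue.put(None)

if __name__ == "__main__":
    queue = mp.Queue()
    proc = mp.Process(target=demo_worker, args=(range(10), 3, queue))
    proc.start()
    while (batch := queue.get()) is not None:
        print(batch)  # [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
    proc.join()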
Example #2
def trim_encoded_sagemaker_parameters(encoded_params: dict,
                                      max_len: int = 256) -> dict:
    """
    Trim parameters that have already been encoded to a given max length.

    Example:

    >>> trim_encoded_sagemaker_parameters({
    ...     'foo': '[1, 2, 3]',
    ...     'bar': 'hello'
    ... }, max_len=4)
    {'_0_foo': '[1, ',
     '_1_foo': '2, 3',
     '_2_foo': ']',
     '_0_bar': 'hell',
     '_1_bar': 'o'}
    """
    trimmed_params = {}
    for key, value in encoded_params.items():
        if len(value) > max_len:
            for idx, substr in enumerate(batcher(value, max_len)):
                trimmed_params[f"_{idx}_{key}"] = "".join(substr)
        else:
            trimmed_params[key] = value
    return trimmed_params
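The _{idx}_{key} naming makes the split reversible. A hypothetical inverse for illustration only (untrim_params is not part of the helpers shown here, and it assumes no original key already matches the "_<idx>_<key>" pattern):

import re
from collections import defaultdict

def untrim_params(trimmed_params: dict) -> dict:
    # Reassemble values split by trim_encoded_sagemaker_parameters.
    chunks = defaultdict(dict)
    result = {}
    for key, value in trimmed_params.items():
        match = re.fullmatch(r"_(\d+)_(.+)", key)
        if match:
            idx, original_key = int(match.group(1)), match.group(2)
            chunks[original_key][idx] = value
        else:
            result[key] = value
    for key, parts in chunks.items():
        result[key] = "".join(parts[idx] for idx in sorted(parts))
    return result

params = {"foo": "[1, 2, 3]", "bar": "hello"}
assert untrim_params(trim_encoded_sagemaker_parameters(params, max_len=4)) == params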
Example #3
def _predict_batch(
    self, dataset: Iterable[Dict], batch_size: int, **kwargs
) -> Iterator[SampleForecast]:
    for batch in batcher(dataset, batch_size):
        yield from (
            self._predict_batch_autoreg(batch, **kwargs)
            if self.auto_regression
            else self._predict_batch_one_shot(batch, **kwargs)
        )
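The yield from here flattens the per-batch forecast iterators back into one stream. A minimal sketch of the same dispatch shape, with stand-in predictors in place of the two _predict_batch_* methods:

from itertools import islice
from typing import Iterable, Iterator

def batcher(iterable, batch_size):
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

def predict_autoreg(batch):
    yield from (("autoreg", item) for item in batch)

def predict_one_shot(batch):
    yield from (("one_shot", item) for item in batch)

def predict(dataset: Iterable, batch_size: int, auto_regression: bool) -> Iterator:
    for batch in batcher(dataset, batch_size):
        yield from (
            predict_autoreg(batch)
            if auto_regression
            else predict_one_shot(batch)
        )

print(list(predict(range(4), 3, auto_regression=False)))
# [('one_shot', 0), ('one_shot', 1), ('one_shot', 2), ('one_shot', 3)]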
Example #4
    def __iter__(self):
        if not self.num_workers:
            # Single-process path: batch and stack in the main process.
            return map(
                self.stack_fn,
                batcher(self.data_iterable, self.batch_size),
            )
        # Multiprocess path: batches are assembled in worker processes.
        return MultiProcessBatcher(
            self.data_iterable,
            batch_size=self.batch_size,
            stack_fn=self.stack_fn,
            decode_fn=self.decode_fn,
            num_workers=self.num_workers,
            max_queue_size=self.num_prefetch,
        )
Example #5
    def __init__(
        self,
        dataset: Dataset,
        *,
        transform: Transformation,
        batch_size: int,
        stack_fn: Callable,
        num_batches_per_epoch: int,
        num_workers: Optional[int] = None,
        num_prefetch: Optional[int] = None,
        shuffle_buffer_length: Optional[int] = None,
        decode_fn: Callable = lambda x: x,
    ) -> None:
        self.batch_size = batch_size
        self.stack_fn = stack_fn
        self.num_batches_per_epoch = num_batches_per_epoch
        self.num_workers = win32_guard(num_workers)
        self.num_prefetch = num_prefetch
        self.shuffle_buffer_length = shuffle_buffer_length

        if not self.num_workers:
            iterator = construct_training_iterator(
                dataset,
                transform=transform,
                shuffle_buffer_length=shuffle_buffer_length,
            )
            self.batch_iterator = map(stack_fn, batcher(iterator, batch_size))
        else:
            self.batch_iterator = MultiProcessBatcher(
                dataset,
                transform=transform,
                batch_size=batch_size,
                stack_fn=stack_fn,
                decode_fn=decode_fn,
                num_workers=self.num_workers,
                max_queue_size=num_prefetch,
                shuffle_buffer_length=shuffle_buffer_length,
            )
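win32_guard is used above without its definition. A plausible stand-in, purely an assumption about its intent (fall back to single-process loading on Windows), not the actual gluonts implementation:

import sys
from typing import Optional

def win32_guard(num_workers: Optional[int]) -> Optional[int]:
    # Assumed behavior: multiprocessing data loading is treated as
    # unsupported on Windows, so load in the main process there.
    if sys.platform == "win32":
        return None
    return num_workers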
Example #6
        constant_dataset()[1],
    ],
)
def test_pseudo_shuffled(data: Iterable) -> None:
    list_data = list(data)
    shuffled_iter = PseudoShuffled(iter(list_data), shuffle_buffer_length=5)
    shuffled_data = list(shuffled_iter)
    assert len(shuffled_data) == len(list_data)
    assert all(d in shuffled_data for d in list_data)


@pytest.mark.parametrize(
    "data, expected_elements_per_iteration",
    [
        (Cached(range(4)), (list(range(4)), ) * 5),
        (batcher(range(10), 3), ([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], [])),
        (IterableSlice(range(10), 3), ([0, 1, 2], ) * 5),
        (
            IterableSlice(iter(range(10)), 3),
            ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []),
        ),
        (
            IterableSlice(iter(Cyclic(range(5))), 3),
            ([0, 1, 2], [3, 4, 0], [1, 2, 3], [4, 0, 1]),
        ),
    ],
)
def test_iterate_multiple_times(data: Iterable,
                                expected_elements_per_iteration: Tuple[List]):
    for expected_elements in expected_elements_per_iteration:
        assert list(data) == expected_elements
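The IterableSlice cases hinge on whether the wrapped source is re-iterable or a one-shot iterator. A minimal stand-in (the real class lives in gluonts) that shows the distinction the test exercises:

from itertools import islice

class IterableSlice:
    # Minimal stand-in: take a fresh slice of up to `length` items on
    # every iteration. A re-iterable source therefore repeats its head,
    # while a one-shot iterator is consumed a slice at a time.
    def __init__(self, iterable, length):
        self.iterable = iterable
        self.length = length

    def __iter__(self):
        return islice(iter(self.iterable), self.length)

print(list(IterableSlice(range(10), 3)))   # [0, 1, 2] on every pass
sliced = IterableSlice(iter(range(10)), 3)
print(list(sliced), list(sliced))          # [0, 1, 2] [3, 4, 5]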
Example #7
def __call__(self, data, is_train):
    yield from batcher(data, self.batch_size)
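batcher itself groups an iterable into lists of at most batch_size items, with a possibly shorter final batch, as the parametrized test in Example #6 also expects. A minimal stand-in definition and its output:

from itertools import islice

def batcher(iterable, batch_size):
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

print(list(batcher(range(10), 3)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]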
Example #8
def __iter__(self):
    yield from map(
        self.stack_fn,
        batcher(self.transformed_dataset, self.batch_size),
    )
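A minimal sketch of this pattern with numpy stacking (Loader and the inline batcher are stand-ins): because __iter__ builds a fresh map/batcher pipeline on every call, the object stays re-iterable whenever the underlying dataset is.

from itertools import islice
import numpy as np

def batcher(iterable, batch_size):
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

class Loader:
    def __init__(self, dataset, batch_size, stack_fn):
        self.dataset = dataset
        self.batch_size = batch_size
        self.stack_fn = stack_fn

    def __iter__(self):
        # A fresh pipeline per iteration keeps the loader re-iterable.
        yield from map(self.stack_fn, batcher(self.dataset, self.batch_size))

loader = Loader([np.ones(2)] * 5, batch_size=2, stack_fn=np.stack)
print([batch.shape for batch in loader])  # [(2, 2), (2, 2), (1, 2)]
print([batch.shape for batch in loader])  # same again: re-iterable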