Пример #1
0
def Dequeue(input_queue: queue.Queue, check=lambda: True):
    """Build an iterator that drains a queue.Queue without blocking.

    Because reads never block, this operator can be executed together with
    Enqueue via the Concurrently() operator.

    Arguments:
        input_queue (Queue): the queue that items are read from.
        check (fn): liveness probe evaluated before every read. As soon as
            it returns a falsy value, the iterator raises RuntimeError to
            halt execution.

    Returns:
        A LocalIterator wrapping a generator that yields each dequeued
        item, or a _NextValueNotReady marker when the queue is empty.

    Raises:
        ValueError: if input_queue is not a queue.Queue instance.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        found = type(input_queue)
        raise ValueError("Expected queue.Queue, got {}".format(found))

    def drain(timeout=None):
        while True:
            # The liveness probe runs before every read attempt.
            if not check():
                raise RuntimeError("Error raised reading from queue")
            try:
                yield input_queue.get_nowait()
            except queue.Empty:
                # Nothing available right now; let the caller move on.
                yield _NextValueNotReady()

    return LocalIterator(drain, MetricsContext())
Пример #2
0
    def get_shard(self, shard_index: int) -> "LocalIterator[T]":
        """Build a local iterator over a single shard.

        The iterator is guaranteed to be serializable and can be passed to
        remote tasks or actors.

        Arguments:
            shard_index (int): position of the shard, counted across all
                actor sets in order.

        Raises:
            ValueError: if no actor could be located for shard_index.
        """
        remaining = shard_index
        actor, transforms = None, None
        for actor_set in self.actor_sets:
            if remaining >= len(actor_set.actors):
                # The shard lives in a later set; skip this set's actors.
                remaining -= len(actor_set.actors)
                continue
            actor = actor_set.actors[remaining]
            transforms = actor_set.transforms
            break
        if actor is None:
            raise ValueError("Shard index out of range", shard_index,
                             self.num_shards())

        def base_iterator(timeout=None):
            ray.get(actor.par_iter_init.remote(transforms))
            bounded_wait = timeout is not None
            while True:
                try:
                    pending = actor.par_iter_next.remote()
                    yield ray.get(pending, timeout=timeout)
                    # Always yield after each round of gets with timeout.
                    if bounded_wait:
                        yield _NextValueNotReady()
                except TimeoutError:
                    yield _NextValueNotReady()
                except StopIteration:
                    # The remote shard is exhausted.
                    return

        shard_name = self.name + ".shard[{}]".format(shard_index)
        return LocalIterator(base_iterator, MetricsContext(), name=shard_name)
Пример #3
0
    def union(self, other: "LocalIterator[T]",
              deterministic: bool = False) -> "LocalIterator[T]":
        """Merge this iterator and another one into a single iterator.

        With deterministic=True the two sources are read in strict
        alternation, one item at a time. Otherwise items are emitted from
        each source for as long as it has values ready.

        Arguments:
            other (LocalIterator): the iterator to merge with this one.
            deterministic (bool): whether to alternate strictly between
                the two sources.

        Raises:
            ValueError: if other is not a LocalIterator.
        """
        if not isinstance(other, LocalIterator):
            raise ValueError(
                "other must be of type LocalIterator, got {}".format(
                    type(other)))

        # No timeout in deterministic mode, a zero timeout otherwise.
        timeout = None if deterministic else 0

        # Re-wrap both sources so that they pick up the chosen timeout.
        active = [
            LocalIterator(
                src.base_iterator,
                src.metrics,
                src.local_transforms,
                timeout=timeout) for src in (self, other)
        ]

        def build_union(timeout=None):
            while active:
                for src in list(active):
                    # Drain the source until it reports _NextValueNotReady
                    # (or after a single item in deterministic mode), then
                    # move on to the next source.
                    try:
                        while True:
                            value = next(src)
                            if isinstance(value, _NextValueNotReady):
                                break
                            yield value
                            if deterministic:
                                break
                    except StopIteration:
                        # Exhausted sources drop out of the rotation.
                        active.remove(src)

        # TODO(ekl) is this the best way to represent union() of metrics?
        union_ctx = MetricsContext()
        union_ctx.parent_metrics.append(self.metrics)
        union_ctx.parent_metrics.append(other.metrics)

        union_name = "LocalUnion[{}, {}]".format(self, other)
        return LocalIterator(build_union, union_ctx, [], name=union_name)
Пример #4
0
 def par_iter_init(self, transforms):
     """Set up the worker-side iterator for a ParallelIterator.

     Wraps self.item_generator in a LocalIterator, applies every transform
     in order, and stores the resulting iterator on self.local_it.

     Arguments:
         transforms: iterable of callables, each taking the current
             iterator and returning the next one (never None).
     """
     def source(timeout):
         return self.item_generator

     local_it = LocalIterator(source, MetricsContext())
     for transform in transforms:
         local_it = transform(local_it)
         assert local_it is not None, transform
     self.local_it = iter(local_it)
Пример #5
0
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator that gathers experiences from rollout workers in parallel.

    When the worker set has no remote workers, experiences are sampled
    serially from the local worker instead (mode is not consulted then).

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Raises:
        ValueError: if remote workers exist and mode is not recognized.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    def report_timesteps(batch):
        # Account for the sampled steps, then pass the batch through.
        counters = LocalIterator.get_metrics().counters
        counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Serial sampling: pull batches from the local worker forever.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        serial_it = LocalIterator(sampler, MetricsContext())
        return serial_it.for_each(report_timesteps)

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    if mode != "bulk_sync":
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))

    per_round = rollouts.batch_across_shards()
    merged = per_round.for_each(
        lambda batches: SampleBatch.concat_samples(batches))
    return merged.for_each(report_timesteps)
Пример #6
0
def LocalReplay(replay_buffer: ReplayBuffer,
                train_batch_size: int,
                prioritized_replay_beta=None):
    """Replay experiences from a local buffer instance.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Arguments:
        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.
        prioritized_replay_beta: Object exposing
            ``value(num_steps_trained)``; its result is passed as ``beta``
            when sampling from a PrioritizedReplayBuffer. Required for
            prioritized buffers, unused otherwise.

    Returns:
        A LocalIterator that endlessly yields MultiAgentBatch objects
        holding one SampleBatch (keys "obs", "actions", "rewards",
        "new_obs", "dones", "weights", "batch_indexes") under
        DEFAULT_POLICY_ID. For non-prioritized buffers, "weights" is all
        ones and "batch_indexes" is all -1.

    Raises:
        ValueError: if replay_buffer is a PrioritizedReplayBuffer but no
            prioritized_replay_beta was supplied.

    Examples:
        >>> replay_op = LocalReplay(ReplayBuffer(...), train_batch_size=32)
        >>> next(replay_op)
        MultiAgentBatch(...)
    """
    assert isinstance(replay_buffer, ReplayBuffer)
    # Fail fast: the prioritized path below needs a beta source. Previously
    # beta was hard-coded to None, so that path always died with an
    # AttributeError on the first fetch.
    if (prioritized_replay_beta is None
            and isinstance(replay_buffer, PrioritizedReplayBuffer)):
        raise ValueError(
            "prioritized_replay_beta is required when replaying from a "
            "PrioritizedReplayBuffer")

    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options, or combine with ParallelReplay (?)
    synchronize_sampling = False

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, buffer in replay_buffers.items():
                # Draw fresh indices for every buffer unless sampling is
                # synchronized, in which case the first draw is reused.
                if idxes is None or not synchronize_sampling:
                    idxes = buffer.sample_idxes(train_batch_size)

                if isinstance(buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    beta = prioritized_replay_beta.value(num_steps_trained)
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = buffer.sample_with_idxes(
                         idxes, beta=beta)
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = buffer.sample_with_idxes(idxes)
                    # Uniform replay: unit weights and placeholder indexes.
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())
Пример #7
0
    def union(self,
              *others: "LocalIterator[T]",
              deterministic: bool = False) -> "LocalIterator[T]":
        """Merge this iterator with any number of others.

        With deterministic=True the sources are read round robin, one item
        per turn. Otherwise items are emitted from each source while it has
        values ready, up to a fixed per-turn cap.

        Arguments:
            others (LocalIterator): the iterators to merge with this one.
            deterministic (bool): whether to force round-robin reads.

        Raises:
            ValueError: if any element of others is not a LocalIterator.
        """
        for candidate in others:
            if isinstance(candidate, LocalIterator):
                continue
            raise ValueError(
                "other must be of type LocalIterator, got {}".format(
                    type(candidate)))

        # No timeout in deterministic mode, a zero timeout otherwise.
        timeout = None if deterministic else 0

        # Every re-wrapped source reports into one shared metrics context.
        shared_metrics = MetricsContext()
        active = [
            LocalIterator(src.base_iterator,
                          shared_metrics,
                          src.local_transforms,
                          timeout=timeout)
            for src in [self] + list(others)
        ]

        def build_union(timeout=None):
            # To avoid starvation, at most max_yield items are taken from a
            # source per turn; a budget of one forces round robin.
            max_yield = 1 if deterministic else 20
            while active:
                for src in list(active):
                    # Yield items from the source until _NextValueNotReady
                    # is found or the budget is spent, then switch.
                    try:
                        for _ in range(max_yield):
                            value = next(src)
                            if isinstance(value, _NextValueNotReady):
                                break
                            yield value
                    except StopIteration:
                        # Exhausted sources drop out of the rotation.
                        active.remove(src)

        others_repr = ", ".join(map(str, others))
        union_name = "LocalUnion[{}, {}]".format(self, others_repr)
        return LocalIterator(build_union, shared_metrics, [], name=union_name)
Пример #8
0
    def gather_async(self) -> "LocalIterator[T]":
        """Return a local iterator that gathers shard items asynchronously.

        A new item is requested from a shard as soon as its previous one
        has been consumed. Items arrive in non-deterministic order.

        Examples:
            >>> it = from_range(100, 1).gather_async()
            >>> next(it)
            ... 3
            >>> next(it)
            ... 0
            >>> next(it)
            ... 1
        """

        metrics = MetricsContext()

        def base_iterator(timeout=None):
            actors = []
            for actor_set in self.actor_sets:
                actor_set.init_actors()
                actors.extend(actor_set.actors)

            # Maps each in-flight request to the actor that will answer it.
            in_flight = {a.par_iter_next.remote(): a for a in actors}
            while in_flight:
                waiting = list(in_flight)
                if timeout is not None:
                    ready, _ = ray.wait(waiting,
                                        num_returns=len(waiting),
                                        timeout=timeout)
                else:
                    # First try to do a batch wait for efficiency.
                    ready, _ = ray.wait(waiting,
                                        num_returns=len(waiting),
                                        timeout=0)
                    # Fall back to a blocking wait.
                    if not ready:
                        ready, _ = ray.wait(waiting, num_returns=1)

                for ref in ready:
                    actor = in_flight.pop(ref)
                    try:
                        metrics.cur_actor = actor
                        yield ray.get(ref)
                        # Re-arm the actor only after its item was consumed.
                        in_flight[actor.par_iter_next.remote()] = actor
                    except StopIteration:
                        # This shard is exhausted; do not re-arm it.
                        pass

                # Always yield after each round of wait with timeout.
                if timeout is not None:
                    yield _NextValueNotReady()

        gathered_name = "{}.gather_async()".format(self)
        return LocalIterator(base_iterator, metrics, name=gathered_name)
Пример #9
0
def LocalReplay(replay_buffer, train_batch_size,
                prioritized_replay_beta=None):
    """Replay experiences from a local replay buffer.

    Arguments:
        replay_buffer: Buffer to sample from. It must provide
            ``sample_idxes(n)`` and ``sample_with_idxes(idxes)``.
        train_batch_size (int): Number of indices drawn per fetch.
        prioritized_replay_beta: Object exposing
            ``value(num_steps_trained)``; its result is passed as ``beta``
            when sampling from a PrioritizedReplayBuffer. Required for
            prioritized buffers, unused otherwise.

    Returns:
        A LocalIterator that endlessly yields MultiAgentBatch objects
        holding one SampleBatch under DEFAULT_POLICY_ID. For
        non-prioritized buffers, "weights" is all ones and
        "batch_indexes" is all -1.

    Raises:
        ValueError: if replay_buffer is a PrioritizedReplayBuffer but no
            prioritized_replay_beta was supplied.
    """
    # Fail fast: the prioritized path below needs a beta source. Previously
    # beta was hard-coded to None, so that path always died with an
    # AttributeError on the first fetch.
    if (prioritized_replay_beta is None
            and isinstance(replay_buffer, PrioritizedReplayBuffer)):
        raise ValueError(
            "prioritized_replay_beta is required when replaying from a "
            "PrioritizedReplayBuffer")

    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options
    synchronize_sampling = False

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, buffer in replay_buffers.items():
                # Draw fresh indices for every buffer unless sampling is
                # synchronized, in which case the first draw is reused.
                if idxes is None or not synchronize_sampling:
                    idxes = buffer.sample_idxes(train_batch_size)

                if isinstance(buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    beta = prioritized_replay_beta.value(num_steps_trained)
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = buffer.sample_with_idxes(
                         idxes, beta=beta)
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = buffer.sample_with_idxes(idxes)
                    # Uniform replay: unit weights and placeholder indexes.
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())
Пример #10
0
    def batch_across_shards(self) -> "LocalIterator[List[T]]":
        """Iterate over the shards in lockstep, one list of items per round.

        Each yielded list holds the next item of every shard that is still
        active. Shards that run out are dropped from later rounds.

        Examples:
            >>> it = from_iterators([range(3), range(3)])
            >>> next(it.batch_across_shards())
            ... [0, 0]
        """

        def base_iterator(timeout=None):
            active = []
            for actor_set in self.actor_sets:
                actor_set.init_actors()
                active.extend(actor_set.actors)

            def request_round():
                # One outstanding request per shard that is still active.
                return [a.par_iter_next.remote() for a in active]

            futures = request_round()
            while active:
                try:
                    yield ray.get(futures, timeout=timeout)
                    futures = request_round()
                    # Always yield after each round of gets with timeout.
                    if timeout is not None:
                        yield _NextValueNotReady()
                except TimeoutError:
                    yield _NextValueNotReady()
                except StopIteration:
                    # Fetch one by one to find and remove the actor(s) that
                    # produced StopIteration, keeping the other results.
                    partial = []
                    for actor, future in zip(list(active), futures):
                        try:
                            partial.append(ray.get(future))
                        except StopIteration:
                            active.remove(actor)
                    if partial:
                        yield partial
                    futures = request_round()

        batched_name = "{}.batch_across_shards()".format(self)
        return LocalIterator(
            base_iterator, MetricsContext(), name=batched_name)
Пример #11
0
    def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
        """Return a local iterator that gathers shard items asynchronously.

        A new item is requested from a shard as soon as a previous one has
        been consumed. Items arrive in non-deterministic order.

        Arguments:
            async_queue_depth (int): The max number of async requests in flight
                per actor. Increasing this improves the amount of pipeline
                parallelism in the iterator.

        Raises:
            ValueError: if async_queue_depth is smaller than 1.

        Examples:
            >>> it = from_range(100, 1).gather_async()
            >>> next(it)
            ... 3
            >>> next(it)
            ... 0
            >>> next(it)
            ... 1
        """

        if async_queue_depth < 1:
            raise ValueError("queue depth must be positive")

        def base_iterator(timeout=None):
            metrics = LocalIterator.get_metrics()
            actors = []
            for actor_set in self.actor_sets:
                actor_set.init_actors()
                actors.extend(actor_set.actors)

            # Maps each in-flight request to the actor that will answer it;
            # every actor starts with async_queue_depth requests queued.
            in_flight = {}
            for _ in range(async_queue_depth):
                for a in actors:
                    in_flight[a.par_iter_next.remote()] = a

            while in_flight:
                waiting = list(in_flight)
                if timeout is not None:
                    ready, _ = ray.wait(waiting,
                                        num_returns=len(waiting),
                                        timeout=timeout)
                else:
                    # First try to do a batch wait for efficiency.
                    ready, _ = ray.wait(waiting,
                                        num_returns=len(waiting),
                                        timeout=0)
                    # Fall back to a blocking wait.
                    if not ready:
                        ready, _ = ray.wait(waiting, num_returns=1)

                for ref in ready:
                    actor = in_flight.pop(ref)
                    try:
                        metrics.current_actor = actor
                        yield ray.get(ref)
                        # Re-arm the actor only after its item was consumed.
                        in_flight[actor.par_iter_next.remote()] = actor
                    except StopIteration:
                        # This shard is exhausted; do not re-arm it.
                        pass

                # Always yield after each round of wait with timeout.
                if timeout is not None:
                    yield _NextValueNotReady()

        gathered_name = "{}.gather_async()".format(self)
        return LocalIterator(
            base_iterator, MetricsContext(), name=gathered_name)