def Dequeue(input_queue: queue.Queue,
            check=lambda: True) -> LocalIterator[SampleBatchType]:
    """Dequeue data items from a queue.Queue instance.

    The dequeue is non-blocking, so Dequeue operations can be executed with
    Enqueue via the Concurrently() operator.

    Args:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check. When this function returns false,
            Dequeue() will raise an error to halt execution.

    Examples:
        >>> queue = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
        >>> read_op = Dequeue(queue)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        while check():
            try:
                item = input_queue.get_nowait()
                yield item
            except queue.Empty:
                yield _NextValueNotReady()
        # Only reachable once check() has returned False.
        raise RuntimeError("Dequeue check() returned False; "
                           "halting the Dequeue iterator.")

    return LocalIterator(base_iterator, SharedMetrics())
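# A minimal, self-contained sketch of the non-blocking dequeue pattern above,
# using a local stand-in for RLlib's _NextValueNotReady sentinel. The names
# here (_NotReady, dequeue_iter) are hypothetical and for illustration only.
import queue


class _NotReady:
    """Stand-in sentinel yielded when the queue has no item yet."""


def dequeue_iter(q: "queue.Queue", check=lambda: True):
    # Drain without blocking: yield a sentinel instead of waiting, so this
    # generator can be interleaved with a producer by a scheduler.
    while check():
        try:
            yield q.get_nowait()
        except queue.Empty:
            yield _NotReady()
    raise RuntimeError("Liveness check returned False; halting dequeue.")


q = queue.Queue(100)
q.put("batch-0")
it = dequeue_iter(q)
assert next(it) == "batch-0"
assert isinstance(next(it), _NotReady)  # empty queue -> sentinel, no block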
def get_shard(self, shard_index: int) -> "LocalIterator[T]": """Return a local iterator for the given shard. The iterator is guaranteed to be serializable and can be passed to remote tasks or actors. """ a, t = None, None i = shard_index for actor_set in self.actor_sets: if i < len(actor_set.actors): a = actor_set.actors[i] t = actor_set.transforms break else: i -= len(actor_set.actors) if a is None: raise ValueError("Shard index out of range", shard_index, self.num_shards()) def base_iterator(timeout=None): ray.get(a.par_iter_init.remote(t)) while True: try: yield ray.get(a.par_iter_next.remote(), timeout=timeout) # Always yield after each round of gets with timeout. if timeout is not None: yield _NextValueNotReady() except TimeoutError: yield _NextValueNotReady() except StopIteration: break name = self.name + ".shard[{}]".format(shard_index) return LocalIterator(base_iterator, SharedMetrics(), name=name)
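# Hedged usage sketch for get_shard(), assuming a running Ray instance and
# the ray.util.iter helpers (from_items places items round robin across
# shards, so shard 0 should hold [1, 3] here).
import ray
from ray.util.iter import from_items

ray.init(ignore_reinit_error=True)
par_it = from_items([1, 2, 3, 4], num_shards=2)
shard0 = par_it.get_shard(0)  # a serializable LocalIterator over shard 0
print(list(shard0))  # expected: [1, 3]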
def par_iter_init(self, transforms): """Implements ParallelIterator worker init.""" it = LocalIterator(lambda timeout: self.item_generator, SharedMetrics()) for fn in transforms: it = fn(it) assert it is not None, fn self.local_it = iter(it)
def union(self, *others: "LocalIterator[T]", deterministic: bool = False) -> "LocalIterator[T]": """Return an iterator that is the union of this and the others. If deterministic=True, we alternate between reading from one iterator and the others. Otherwise we return items from iterators as they become ready. """ for it in others: if not isinstance(it, LocalIterator): raise ValueError( "other must be of type LocalIterator, got {}".format( type(it))) if deterministic: timeout = None else: timeout = 0 active = [] parent_iters = [self] + list(others) shared_metrics = SharedMetrics( parents=[p.shared_metrics for p in parent_iters]) for it in parent_iters: active.append( LocalIterator(it.base_iterator, shared_metrics, it.local_transforms, timeout=timeout)) def build_union(timeout=None): while True: for it in list(active): # Yield items from the iterator until _NextValueNotReady is # found, then switch to the next iterator. # To avoid starvation, we yield at most max_yield items per # iterator before switching. if deterministic: max_yield = 1 # Forces round robin. else: max_yield = 20 try: for _ in range(max_yield): item = next(it) if isinstance(item, _NextValueNotReady): break else: yield item except StopIteration: active.remove(it) if not active: break return LocalIterator(build_union, shared_metrics, [], name="LocalUnion[{}, {}]".format( self, ", ".join(map(str, others))))
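# Hedged usage sketch of deterministic union (assumes a running Ray instance
# and the ray.util.iter helpers): with deterministic=True, max_yield=1 forces
# strict round robin between the two local iterators.
import itertools
import ray
from ray.util.iter import from_items

ray.init(ignore_reinit_error=True)
a = from_items([1, 2, 3], num_shards=1).gather_sync()
b = from_items([4, 5, 6], num_shards=1).gather_sync()
u = a.union(b, deterministic=True)
print(list(itertools.islice(u, 6)))  # expected alternation: [1, 4, 2, 5, 3, 6]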
def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
    """Replay experiences from a local buffer instance.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Arguments:
        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.

    Examples:
        >>> replay_buffer = ReplayBuffer(...)
        >>> replay_op = LocalReplay(replay_buffer, train_batch_size=512)
        >>> next(replay_op)
        MultiAgentBatch(...)
    """
    assert isinstance(replay_buffer, ReplayBuffer)
    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options, or combine with ParallelReplay (?)
    synchronize_sampling = False
    # Note: must be set to a schedule before a PrioritizedReplayBuffer can be
    # passed in; the prioritized branch below calls .value() on it.
    prioritized_replay_beta = None

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[
                        STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(
                             num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, SharedMetrics())
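# Hedged sketch of the intended wiring: a store op writes rollouts into the
# buffer while LocalReplay reads from it, and Concurrently() interleaves the
# two. `workers` (a WorkerSet) and `replay_buffer` are assumed to exist, and
# the StoreToReplayBuffer constructor signature varies across RLlib versions
# (positional buffer in older releases, local_buffer=... keyword in newer).
rollouts = ParallelRollouts(workers, mode="bulk_sync")
store_op = rollouts.for_each(StoreToReplayBuffer(replay_buffer))
replay_op = LocalReplay(replay_buffer, train_batch_size=512)
train_op = Concurrently([store_op, replay_op], mode="round_robin")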
def get_shard(self,
              shard_index: int,
              batch_ms: int = 0,
              num_async: int = 1) -> "LocalIterator[T]":
    """Return a local iterator for the given shard.

    The iterator is guaranteed to be serializable and can be passed to
    remote tasks or actors.

    Arguments:
        shard_index (int): Index of the shard to gather.
        batch_ms (int): Batches items on the shard for batch_ms milliseconds
            before returning them. Increasing batch_ms increases latency but
            improves throughput. If this value is 0, then items are returned
            immediately.
        num_async (int): The max number of requests in flight. Increasing
            this improves the amount of pipeline parallelism in the iterator.
    """
    if num_async < 1:
        raise ValueError("num_async must be positive")
    if batch_ms < 0:
        raise ValueError("batch_ms must be non-negative")
    a, t = None, None
    i = shard_index
    for actor_set in self.actor_sets:
        if i < len(actor_set.actors):
            a = actor_set.actors[i]
            t = actor_set.transforms
            break
        else:
            i -= len(actor_set.actors)
    if a is None:
        raise ValueError("Shard index out of range", shard_index,
                         self.num_shards())

    def base_iterator(timeout=None):
        queue = collections.deque()
        ray.get(a.par_iter_init.remote(t))
        for _ in range(num_async):
            queue.append(a.par_iter_next_batch.remote(batch_ms))
        while True:
            try:
                batch = ray.get(queue.popleft(), timeout=timeout)
                queue.append(a.par_iter_next_batch.remote(batch_ms))
                for item in batch:
                    yield item
                # Always yield after each round of gets with timeout.
                if timeout is not None:
                    yield _NextValueNotReady()
            except TimeoutError:
                yield _NextValueNotReady()
            except StopIteration:
                break

    name = self.name + f".shard[{shard_index}]"
    return LocalIterator(base_iterator, SharedMetrics(), name=name)
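# Hedged usage sketch: let the worker batch items for ~5 ms before each
# return, and keep two fetches in flight so transfer overlaps with
# iteration. `par_it` is assumed to be a ParallelIterator built elsewhere.
shard0 = par_it.get_shard(0, batch_ms=5, num_async=2)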
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of a DQN-style algorithm using reservoir replay.

    Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    replay_buffer_actor = ReservoirReplayActor.remote(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config["replay_sequence_length"],
    )

    # Store a handle for the replay buffer actor in the local worker.
    workers.local_worker().replay_buffer_actor = replay_buffer_actor

    # Read and train on experiences from the replay buffer. Every batch
    # returned from the Replay iterator is passed to TrainOneStep to
    # take an SGD step.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

    def gen_replay(_):
        while True:
            item = ray.get(replay_buffer_actor.replay.remote())
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    replay_op = LocalIterator(gen_replay, SharedMetrics()) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers))

    replay_op = StandardMetricsReporting(replay_op, workers, config)
    # Replace not-ready sentinels with empty dicts so that consumers of the
    # metrics stream always receive a dict.
    replay_op = map(
        lambda x: x if not isinstance(x, _NextValueNotReady) else {},
        replay_op)

    return replay_op
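# A minimal, self-contained demo of the trailing `map` step above, using a
# hypothetical stand-in for the _NextValueNotReady sentinel:
class _NotReady:
    pass


stream = [{"loss": 0.10}, _NotReady(), {"loss": 0.09}]
cleaned = map(lambda x: x if not isinstance(x, _NotReady) else {}, stream)
print(list(cleaned))  # [{'loss': 0.1}, {}, {'loss': 0.09}]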
def Replay(
    *,
    local_buffer: Optional[MultiAgentReplayBuffer] = None,
    num_items_to_replay: int = 1,
    actors: Optional[List[ActorHandle]] = None,
    num_async: int = 4,
) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Only one of this and replay_actors
            can be specified.
        num_items_to_replay: Number of items to sample from the buffer.
        actors: List of replay actors. Only one of this and local_buffer
            can be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors, # doctest: +SKIP
        ...     num_items_to_replay=batch_size)
        >>> next(replay_op) # doctest: +SKIP
        SampleBatch(...)
    """
    # Exactly one source must be given: raise if both or neither is set.
    if (local_buffer is None) == (actors is None):
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

    if actors is not None:
        for actor in actors:
            actor.make_iterator.remote(
                num_items_to_replay=num_items_to_replay)
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.sample(num_items_to_replay)
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())
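# Hedged sketch of the local-buffer variant: sample 32 items per pull from a
# MultiAgentReplayBuffer. The import path and `capacity` keyword follow
# recent RLlib versions and may need adjusting for yours.
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer

buf = MultiAgentReplayBuffer(capacity=50000)
replay_op = Replay(local_buffer=buf, num_items_to_replay=32)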
def union(self, *others: "LocalIterator[T]", deterministic: bool = False) -> "LocalIterator[T]": """Return an iterator that is the union of this and the others. If deterministic=True, we alternate between reading from one iterator and the others. Otherwise we return items from iterators as they become ready. """ for it in others: if not isinstance(it, LocalIterator): raise ValueError( "other must be of type LocalIterator, got {}".format( type(it))) timeout = None if deterministic else 0 active = [] parent_iters = [self] + list(others) shared_metrics = SharedMetrics( parents=[p.shared_metrics for p in parent_iters]) for it in parent_iters: active.append( LocalIterator( it.base_iterator, shared_metrics, it.local_transforms, timeout=timeout)) def build_union(timeout=None): while True: for it in list(active): try: item = next(it) if isinstance(item, _NextValueNotReady): if timeout is not None: yield item else: yield item except StopIteration: active.remove(it) if not active: break return LocalIterator( build_union, shared_metrics, [], name="LocalUnion[{}, {}]".format(self, ", ".join(map(str, others))))
def Replay(*, local_buffer: LocalReplayBuffer = None, actors: List["ActorHandle"] = None, num_async=4): """Replay experiences from the given buffer or actors. This should be combined with the StoreToReplayActors operation using the Concurrently() operator. Arguments: local_buffer (LocalReplayBuffer): Local buffer to use. Only one of this and replay_actors can be specified. actors (list): List of replay actors. Only one of this and local_buffer can be specified. num_async (int): In async mode, the max number of async requests in flight per actor. Examples: >>> actors = [ReplayActor.remote() for _ in range(4)] >>> replay_op = Replay(actors=actors) >>> next(replay_op) SampleBatch(...) Code From: https://github.com/ray-project/ray/blob/ray-0.8.7/rllib/execution/replay_ops.py """ if bool(local_buffer) == bool(actors): raise ValueError( "Exactly one of local_buffer and replay_actors must be given.") if actors: replay = from_actors(actors) return replay.gather_async( num_async=num_async).filter(lambda x: x is not None) def gen_replay(_): while True: item = local_buffer.replay() if item is None: yield _NextValueNotReady() else: yield item return LocalIterator(gen_replay, SharedMetrics())
def Replay(*, local_buffer: MultiAgentReplayBuffer = None, actors: List[ActorHandle] = None, num_async: int = 4) -> LocalIterator[SampleBatchType]: """Replay experiences from the given buffer or actors. This should be combined with the StoreToReplayActors operation using the Concurrently() operator. Args: local_buffer: Local buffer to use. Only one of this and replay_actors can be specified. actors: List of replay actors. Only one of this and local_buffer can be specified. num_async: In async mode, the max number of async requests in flight per actor. Examples: >>> actors = [ReplayActor.remote() for _ in range(4)] >>> replay_op = Replay(actors=actors) >>> next(replay_op) SampleBatch(...) """ if bool(local_buffer) == bool(actors): raise ValueError( "Exactly one of local_buffer and replay_actors must be given.") if actors: replay = from_actors(actors) return replay.gather_async( num_async=num_async).filter(lambda x: x is not None) def gen_replay(_): while True: item = local_buffer.replay() if item is None: yield _NextValueNotReady() else: yield item return LocalIterator(gen_replay, SharedMetrics())
def Dequeue(input_queue: queue.Queue, check=lambda: True) -> LocalIterator[SampleBatchType]: """Dequeue data items from a queue.Queue instance. The dequeue is non-blocking, so Dequeue operations can execute with Enqueue via the Concurrently() operator. Args: input_queue: queue to pull items from. check: liveness check. When this function returns false, Dequeue() will raise an error to halt execution. Examples: >>> import queue >>> from ray.rllib.execution import ParallelRollouts >>> queue = queue.Queue(100) # doctest: +SKIP >>> write_op = ParallelRollouts(...) # doctest: +SKIP ... .for_each(Enqueue(queue)) >>> read_op = Dequeue(queue) # doctest: +SKIP >>> combined_op = Concurrently( # doctest: +SKIP ... [write_op, read_op], mode="async") >>> next(combined_op) # doctest: +SKIP SampleBatch(...) """ if not isinstance(input_queue, queue.Queue): raise ValueError("Expected queue.Queue, got {}".format( type(input_queue))) def base_iterator(timeout=None): while check(): try: item = input_queue.get(timeout=0.001) yield item except queue.Empty: yield _NextValueNotReady() raise RuntimeError("Dequeue `check()` returned False! " "Exiting with Exception from Dequeue iterator.") return LocalIterator(base_iterator, SharedMetrics())
def LocalReservoirMultiagent(reservoir_buffers, train_batch_size,
                             min_size_to_learn, learn_every):
    """Sample experiences from a multi-agent reservoir buffer.

    Arguments:
        reservoir_buffers (MultiAgentReservoirBuffer): Buffers to replay
            experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.
        min_size_to_learn (int): Minimum buffer length to start learning.
        learn_every (int): Number of steps between learning iterations.
    """
    assert isinstance(reservoir_buffers, MultiAgentReservoirBuffer)

    def gen_replay(timeout):
        while True:
            samples = {}
            for policy_id, reservoir_buffer in \
                    reservoir_buffers.buffers.items():
                if len(reservoir_buffer) >= min_size_to_learn and \
                        reservoir_buffers.steps[policy_id] >= learn_every:
                    idxes = reservoir_buffer.sample_idxes(train_batch_size)
                    (obses_t,
                     actions) = reservoir_buffer.sample_with_idxes(idxes)
                    samples[policy_id] = SampleBatch({
                        "obs": obses_t,
                        "actions": actions,
                    })
                    reservoir_buffers.steps[policy_id] = 0
            if samples == {}:
                yield _NextValueNotReady()
            else:
                yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, SharedMetrics())
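# Hedged usage sketch: sample 128-item batches per policy, but only once a
# policy's reservoir holds at least 1000 items and at least 64 steps have
# passed since its last learning iteration; otherwise the iterator yields a
# not-ready sentinel. `reservoir_buffers` is assumed to be built elsewhere.
replay_op = LocalReservoirMultiagent(
    reservoir_buffers,
    train_batch_size=128,
    min_size_to_learn=1000,
    learn_every=64)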
def batch_across_shards(self) -> "LocalIterator[List[T]]": """Iterate over the results of multiple shards in parallel. Examples: >>> it = from_iterators([range(3), range(3)]) >>> next(it.batch_across_shards()) ... [0, 0] """ def base_iterator(timeout=None): active = [] for actor_set in self.actor_sets: actor_set.init_actors() active.extend(actor_set.actors) futures = [a.par_iter_next.remote() for a in active] while active: try: yield ray.get(futures, timeout=timeout) futures = [a.par_iter_next.remote() for a in active] # Always yield after each round of gets with timeout. if timeout is not None: yield _NextValueNotReady() except TimeoutError: yield _NextValueNotReady() except StopIteration: # Find and remove the actor that produced StopIteration. results = [] for a, f in zip(list(active), futures): try: results.append(ray.get(f)) except StopIteration: active.remove(a) if results: yield results futures = [a.par_iter_next.remote() for a in active] name = "{}.batch_across_shards()".format(self) return LocalIterator(base_iterator, SharedMetrics(), name=name)
def SimpleLocalReplayMultiagent(replay_buffers, train_batch_size,
                                min_size_to_learn, learn_every):
    """Replay experiences from a MultiAgentSimpleReplayBuffer instance.

    Arguments:
        replay_buffers (MultiAgentSimpleReplayBuffer): Buffers to replay
            experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.
        min_size_to_learn (int): Minimum buffer length to start learning.
        learn_every (int): Number of steps between learning iterations.
    """
    assert isinstance(replay_buffers, MultiAgentSimpleReplayBuffer)

    def gen_replay(timeout):
        while True:
            samples = {}
            for policy_id, replay_buffer in replay_buffers.buffers.items():
                if len(replay_buffer.replay_batches) >= min_size_to_learn \
                        and replay_buffers.steps[policy_id] >= learn_every:
                    batch = None
                    for _ in range(train_batch_size):
                        if batch is None:
                            batch = replay_buffer.replay(
                            ).decompress_if_needed()
                        else:
                            batch = batch.concat(
                                replay_buffer.replay().decompress_if_needed())
                    replay_buffers.steps[policy_id] = 0
                    samples[policy_id] = batch
            if samples == {}:
                yield _NextValueNotReady()
            else:
                yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, SharedMetrics())
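# Design note (hedged): the pairwise `batch.concat(...)` accumulation above
# copies already-accumulated rows on every step, i.e. O(n^2) copying overall.
# RLlib's SampleBatch.concat_samples builds the batch in one pass; an
# equivalent inner loop would be:
batches = [
    replay_buffer.replay().decompress_if_needed()
    for _ in range(train_batch_size)
]
batch = SampleBatch.concat_samples(batches)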
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync", num_async=1) -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. If there are no remote workers, experiences will be collected serially from the local worker instance instead. Arguments: workers (WorkerSet): set of rollout workers to use. mode (str): One of {'async', 'bulk_sync', 'raw'}. - In 'async' mode, batches are returned as soon as they are computed by rollout workers with no order guarantees. - In 'bulk_sync' mode, we collect one batch from each worker and concatenate them together into a large batch to return. - In 'raw' mode, the ParallelIterator object is returned directly and the caller is responsible for implementing gather and updating the timesteps counter. num_async (int): In async mode, the max number of async requests in flight per actor. Returns: A local iterator over experiences collected in parallel. Examples: >>> rollouts = ParallelRollouts(workers, mode="async") >>> batch = next(rollouts) >>> print(batch.count) 50 # config.rollout_fragment_length >>> rollouts = ParallelRollouts(workers, mode="bulk_sync") >>> batch = next(rollouts) >>> print(batch.count) 200 # config.rollout_fragment_length * config.num_workers Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context. """ # Ensure workers are initially in sync. workers.sync_weights() def report_timesteps(batch): metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count return batch if not workers.remote_workers(): # Handle the serial sampling case. def sampler(_): while True: yield workers.local_worker().sample() return (LocalIterator(sampler, SharedMetrics()).for_each(report_timesteps)) # Create a parallel iterator over generated experiences. rollouts = from_actors(workers.remote_workers()) if mode == "bulk_sync": return rollouts \ .batch_across_shards() \ .for_each(lambda batches: SampleBatch.concat_samples(batches)) \ .for_each(report_timesteps) elif mode == "async": return rollouts.gather_async( num_async=num_async).for_each(report_timesteps) elif mode == "raw": return rollouts else: raise ValueError("mode must be one of 'bulk_sync', 'async', 'raw', " "got '{}'".format(mode))
def iter_list(values):
    """Wrap a plain list of items as a LocalIterator."""
    return LocalIterator(lambda _: values, SharedMetrics())
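# Usage sketch: wrap a plain list as a LocalIterator so it can flow through
# the same operator pipeline (for_each, filter, union, ...) as remote data.
it = iter_list([1, 2, 3])
print(list(it))  # [1, 2, 3]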
def union(self, *others: "LocalIterator[T]",
          deterministic: bool = False,
          round_robin_weights: List[float] = None) -> "LocalIterator[T]":
    """Return an iterator that is the union of this and the others.

    Args:
        deterministic (bool): If deterministic=True, we alternate between
            reading from one iterator and the others. Otherwise we return
            items from iterators as they become ready.
        round_robin_weights (list): List of weights to use for round robin
            mode. For example, [2, 1] will cause the iterator to pull twice
            as many items from the first iterator as the second. [2, 1, "*"]
            will cause as many items to be pulled as possible from the third
            iterator without blocking. This overrides the deterministic flag.
    """

    for it in others:
        if not isinstance(it, LocalIterator):
            raise ValueError(
                "other must be of type LocalIterator, got {}".format(
                    type(it)))

    active = []
    parent_iters = [self] + list(others)
    shared_metrics = SharedMetrics(
        parents=[p.shared_metrics for p in parent_iters])

    timeout = None if deterministic else 0
    if round_robin_weights:
        if len(round_robin_weights) != len(parent_iters):
            raise ValueError(
                "Length of round robin weights must equal number of "
                "iterators total.")
        timeouts = [0 if w == "*" else None for w in round_robin_weights]
    else:
        timeouts = [timeout] * len(parent_iters)
        round_robin_weights = [1] * len(parent_iters)

    for i, it in enumerate(parent_iters):
        active.append(
            LocalIterator(
                it.base_iterator,
                shared_metrics,
                it.local_transforms,
                timeout=timeouts[i]))
    active = list(zip(round_robin_weights, active))

    def build_union(timeout=None):
        while True:
            for weight, it in list(active):
                if weight == "*":
                    max_pull = 100  # TODO(ekl) how to best bound this?
                else:
                    max_pull = _randomized_int_cast(weight)
                try:
                    for _ in range(max_pull):
                        item = next(it)
                        if isinstance(item, _NextValueNotReady):
                            if timeout is not None:
                                yield item
                            break
                        else:
                            yield item
                except StopIteration:
                    active.remove((weight, it))
                    if not active:
                        break

    return LocalIterator(build_union, shared_metrics, [],
                         name="LocalUnion[{}, {}]".format(
                             self, ", ".join(map(str, others))))
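# Hedged sketch of weighted round robin (assumes a running Ray instance):
# pull two items from `a` for every one from `b`. Integer weights are honored
# exactly; a "*" weight instead pulls greedily without blocking.
import itertools
import ray
from ray.util.iter import from_items

ray.init(ignore_reinit_error=True)
a = from_items([0, 1, 2, 3, 4, 5], num_shards=1).gather_sync()
b = from_items([100, 101, 102], num_shards=1).gather_sync()
u = a.union(b, round_robin_weights=[2, 1])
print(list(itertools.islice(u, 9)))  # expected: [0, 1, 100, 2, 3, 101, 4, 5, 102]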
def LocalReplayMultiagent(replay_buffers, train_batch_size,
                          min_size_to_learn, learn_every,
                          learn_every_res=None, prioritized=False, beta=1):
    """Replay experiences from a MultiAgentReplayBuffer instance.

    Soon to be deprecated  # TODO: update

    Arguments:
        replay_buffers (MultiAgentReplayBuffer): Buffers to replay
            experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.
        min_size_to_learn (int): Minimum buffer length to start learning.
        learn_every (int): Number of steps between learning iterations.
        learn_every_res (int): Number of steps between two learning
            iterations of the average network.
        prioritized (bool): DEPRECATED
        beta (float): DEPRECATED
    """
    assert isinstance(replay_buffers, MultiAgentReplayBuffer)

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.buffers.items():
                # Hard-coded: the 'policy_team' buffer must hold twice
                # min_size_to_learn before learning starts.
                policy_multiplier = 2 if policy_id == 'policy_team' else 1
                if len(replay_buffer) >= \
                        min_size_to_learn * policy_multiplier and \
                        replay_buffers.steps[policy_id] >= learn_every:
                    idxes = replay_buffer.sample_idxes(train_batch_size)
                    replay_buffers.steps[policy_id] = 0
                    if prioritized:
                        (obses_t, actions, rewards, obses_tp1, dones,
                         weights, batch_indexes) = \
                            replay_buffer.sample_with_idxes(idxes, beta)
                    else:
                        (obses_t, actions, rewards, obses_tp1,
                         dones) = replay_buffer.sample_with_idxes(idxes)
                        weights = np.ones_like(rewards)
                        batch_indexes = -np.ones_like(rewards)
                    samples[policy_id] = SampleBatch({
                        "obs": obses_t,
                        "actions": actions,
                        "rewards": rewards,
                        "new_obs": obses_tp1,
                        "dones": dones,
                        "weights": weights,
                        "batch_indexes": batch_indexes
                    })
            if samples == {}:
                yield _NextValueNotReady()
            else:
                yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, SharedMetrics())
def gather_async(self, num_async=1) -> "LocalIterator[T]": """Returns a local iterable for asynchronous iteration. New items will be fetched from the shards asynchronously as soon as the previous one is computed. Items arrive in non-deterministic order. Arguments: num_async (int): The max number of async requests in flight per actor. Increasing this improves the amount of pipeline parallelism in the iterator. Examples: >>> it = from_range(100, 1).gather_async() >>> next(it) ... 3 >>> next(it) ... 0 >>> next(it) ... 1 """ if num_async < 1: raise ValueError("queue depth must be positive") # Forward reference to the returned iterator. local_iter = None def base_iterator(timeout=None): all_actors = [] for actor_set in self.actor_sets: actor_set.init_actors() all_actors.extend(actor_set.actors) futures = {} for _ in range(num_async): for a in all_actors: futures[a.par_iter_next.remote()] = a while futures: pending = list(futures) if timeout is None: # First try to do a batch wait for efficiency. ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0) # Fall back to a blocking wait. if not ready: ready, _ = ray.wait(pending, num_returns=1) else: ready, _ = ray.wait(pending, num_returns=len(pending), timeout=timeout) for obj_id in ready: actor = futures.pop(obj_id) try: local_iter.shared_metrics.get().current_actor = actor yield ray.get(obj_id) futures[actor.par_iter_next.remote()] = actor except StopIteration: pass # Always yield after each round of wait with timeout. if timeout is not None: yield _NextValueNotReady() name = "{}.gather_async()".format(self) local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name) return local_iter
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync", num_async=1) -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. If there are no remote workers, experiences will be collected serially from the local worker instance instead. Args: workers (WorkerSet): set of rollout workers to use. mode (str): One of 'async', 'bulk_sync', 'raw'. In 'async' mode, batches are returned as soon as they are computed by rollout workers with no order guarantees. In 'bulk_sync' mode, we collect one batch from each worker and concatenate them together into a large batch to return. In 'raw' mode, the ParallelIterator object is returned directly and the caller is responsible for implementing gather and updating the timesteps counter. num_async (int): In async mode, the max number of async requests in flight per actor. Returns: A local iterator over experiences collected in parallel. Examples: >>> from ray.rllib.execution import ParallelRollouts >>> workers = ... # doctest: +SKIP >>> rollouts = ParallelRollouts(workers, mode="async") # doctest: +SKIP >>> batch = next(rollouts) # doctest: +SKIP >>> print(batch.count) # doctest: +SKIP 50 # config.rollout_fragment_length >>> rollouts = ParallelRollouts(workers, mode="bulk_sync") # doctest: +SKIP >>> batch = next(rollouts) # doctest: +SKIP >>> print(batch.count) # doctest: +SKIP 200 # config.rollout_fragment_length * config.num_workers Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context. """ # Ensure workers are initially in sync. workers.sync_weights() def report_timesteps(batch): metrics = _get_shared_metrics() metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count if isinstance(batch, MultiAgentBatch): metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps( ) else: metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count return batch if not workers.remote_workers(): # Handle the `num_workers=0` case, in which the local worker # has to do sampling as well. return LocalIterator( lambda timeout: workers.local_worker().item_generator, SharedMetrics()).for_each(report_timesteps) # Create a parallel iterator over generated experiences. rollouts = from_actors(workers.remote_workers()) if mode == "bulk_sync": return (rollouts.batch_across_shards().for_each( lambda batches: SampleBatch.concat_samples(batches)).for_each( report_timesteps)) elif mode == "async": return rollouts.gather_async( num_async=num_async).for_each(report_timesteps) elif mode == "raw": return rollouts else: raise ValueError( "mode must be one of 'bulk_sync', 'async', 'raw', got '{}'".format( mode))