def Dequeue(input_queue: queue.Queue, check=lambda: True):
    """Dequeue data items from a queue.Queue instance.

    Reads never block, which lets Dequeue run side by side with Enqueue
    under the Concurrently() operator.

    Arguments:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check, polled before every read. Once it
            returns a falsy value, the iterator raises RuntimeError to halt
            execution.

    Returns:
        A LocalIterator that yields each dequeued item, or a
        _NextValueNotReady placeholder whenever the queue is empty.

    Raises:
        ValueError: if ``input_queue`` is not a queue.Queue.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        # `timeout` is accepted for the base-iterator signature but unused:
        # reads are always non-blocking.
        while check():
            try:
                item = input_queue.get_nowait()
            except queue.Empty:
                # Nothing queued yet; let the consumer switch iterators.
                yield _NextValueNotReady()
            else:
                yield item
        raise RuntimeError("Error raised reading from queue")

    return LocalIterator(base_iterator, MetricsContext())
def get_shard(self, shard_index: int) -> "LocalIterator[T]":
    """Return a local iterator for the given shard.

    The iterator is guaranteed to be serializable and can be passed to
    remote tasks or actors.

    Arguments:
        shard_index (int): global index of the shard, counted across the
            actors of all actor sets in order. Must be non-negative and
            smaller than the total number of actors.

    Raises:
        ValueError: if ``shard_index`` is negative or past the last actor.
    """
    # Reject negative indices explicitly: otherwise Python's negative list
    # indexing below would silently pick an actor from the end of the
    # first actor set.
    if shard_index < 0:
        raise ValueError("Shard index out of range", shard_index,
                         self.num_shards())

    actor, transforms = None, None
    remaining = shard_index
    for actor_set in self.actor_sets:
        if remaining < len(actor_set.actors):
            actor = actor_set.actors[remaining]
            transforms = actor_set.transforms
            break
        remaining -= len(actor_set.actors)
    if actor is None:
        raise ValueError("Shard index out of range", shard_index,
                         self.num_shards())

    def base_iterator(timeout=None):
        # Install this actor set's transforms on the actor before reading.
        ray.get(actor.par_iter_init.remote(transforms))
        # Request that has been issued but not yet received. It is kept
        # across timeouts: a timed-out ray.get does not cancel the remote
        # call, so issuing a fresh request instead would skip that item.
        pending = None
        while True:
            if pending is None:
                pending = actor.par_iter_next.remote()
            try:
                item = ray.get(pending, timeout=timeout)
            except TimeoutError:
                yield _NextValueNotReady()
                continue
            except StopIteration:
                break
            pending = None
            yield item
            # Always yield after each round of gets with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    name = self.name + ".shard[{}]".format(shard_index)
    return LocalIterator(base_iterator, MetricsContext(), name=name)
def union(self, other: "LocalIterator[T]",
          deterministic: bool = False) -> "LocalIterator[T]":
    """Return an iterator that is the union of this and the other.

    If deterministic=True, we alternate between reading from one iterator
    and the other, taking one item from each in turn. Otherwise we return
    items from iterators as they become ready, taking at most 20 items from
    one iterator before switching, so that an iterator which is always
    ready cannot starve the other.

    Arguments:
        other (LocalIterator): iterator to merge with this one.
        deterministic (bool): whether to read in strict alternation.

    Raises:
        ValueError: if ``other`` is not a LocalIterator.
    """
    if not isinstance(other, LocalIterator):
        raise ValueError(
            "other must be of type LocalIterator, got {}".format(
                type(other)))

    # Deterministic unions read each source with timeout=None; otherwise
    # sources are polled with timeout=0.
    timeout = None if deterministic else 0

    it1 = LocalIterator(
        self.base_iterator,
        self.metrics,
        self.local_transforms,
        timeout=timeout)
    it2 = LocalIterator(
        other.base_iterator,
        other.metrics,
        other.local_transforms,
        timeout=timeout)
    active = [it1, it2]

    # Maximum items taken from one source per turn. 1 forces round robin;
    # the non-deterministic cap prevents starvation of the other source.
    max_yield = 1 if deterministic else 20

    def build_union(timeout=None):
        while True:
            for it in list(active):
                # Yield items from the iterator until _NextValueNotReady is
                # found (or max_yield is reached), then switch iterators.
                try:
                    for _ in range(max_yield):
                        item = next(it)
                        if isinstance(item, _NextValueNotReady):
                            break
                        yield item
                except StopIteration:
                    active.remove(it)
            if not active:
                break

    # TODO(ekl) is this the best way to represent union() of metrics?
    new_ctx = MetricsContext()
    new_ctx.parent_metrics.append(self.metrics)
    new_ctx.parent_metrics.append(other.metrics)

    return LocalIterator(
        build_union,
        new_ctx, [],
        name="LocalUnion[{}, {}]".format(self, other))
def par_iter_init(self, transforms):
    """Implements ParallelIterator worker init.

    Wraps this worker's ``item_generator`` in a LocalIterator, applies every
    callable in ``transforms`` in order, and stores an iterator over the
    final result in ``self.local_it``.

    Arguments:
        transforms: sequence of callables, each taking the current
            LocalIterator and returning the transformed one.
    """
    def make_base(timeout):
        # The timeout argument is ignored; items come directly from the
        # worker's generator.
        return self.item_generator

    local = LocalIterator(make_base, MetricsContext())
    for transform in transforms:
        local = transform(local)
        # Catch transforms that forget to return their result.
        assert local is not None, transform
    self.local_it = iter(local)
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead (``mode`` is still validated but
    has no further effect in that case).

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Raises:
        ValueError: if ``mode`` is not 'bulk_sync' or 'async'.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    # Validate before branching so a typo in `mode` is rejected even on
    # the serial path, which would otherwise ignore it silently.
    if mode not in ("bulk_sync", "async"):
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))

    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler, MetricsContext())
                .for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return (rollouts
                .batch_across_shards()
                .for_each(SampleBatch.concat_samples)
                .for_each(report_timesteps))
    return rollouts.gather_async().for_each(report_timesteps)
def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int,
                prioritized_replay_beta=None):
    """Replay experiences from a local buffer instance.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Arguments:
        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.
        prioritized_replay_beta: object exposing ``value(num_steps_trained)``
            whose result is passed as ``beta`` to ``sample_with_idxes``.
            Required when ``replay_buffer`` is a PrioritizedReplayBuffer;
            unused otherwise.

    Returns:
        A LocalIterator that yields MultiAgentBatch objects built from
        ``train_batch_size`` sampled experiences, keyed by DEFAULT_POLICY_ID.

    Raises:
        ValueError: if ``replay_buffer`` is a PrioritizedReplayBuffer and
            ``prioritized_replay_beta`` is not given.

    Examples:
        >>> replay_op = LocalReplay(replay_buffer, train_batch_size=32)
        >>> next(replay_op)
        MultiAgentBatch(...)
    """
    assert isinstance(replay_buffer, ReplayBuffer)
    # Fail at construction instead of with an AttributeError on None the
    # first time the iterator is pulled.
    if (isinstance(replay_buffer, PrioritizedReplayBuffer)
            and prioritized_replay_beta is None):
        raise ValueError(
            "prioritized_replay_beta is required when replaying from a "
            "PrioritizedReplayBuffer")
    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options, or combine with ParallelReplay (?)
    synchronize_sampling = False

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, buffer in replay_buffers.items():
                # With synchronized sampling, every policy reuses the
                # indices drawn for the first one.
                if not synchronize_sampling or idxes is None:
                    idxes = buffer.sample_idxes(train_batch_size)
                if isinstance(buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[
                        STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(
                             num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = buffer.sample_with_idxes(idxes)
                    # Fill unit weights and -1 placeholder indexes so the
                    # batch has the same keys as the prioritized case.
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())
def union(self, *others: "LocalIterator[T]",
          deterministic: bool = False) -> "LocalIterator[T]":
    """Return an iterator that is the union of this and the others.

    With deterministic=True the sources are read in strict round robin, one
    item each per turn. Otherwise items are returned as they become ready,
    with at most 20 items taken from any one source before moving on so no
    source is starved.

    Raises:
        ValueError: if any of ``others`` is not a LocalIterator.
    """
    for candidate in others:
        if not isinstance(candidate, LocalIterator):
            raise ValueError(
                "other must be of type LocalIterator, got {}".format(
                    type(candidate)))

    timeout = None if deterministic else 0
    shared_metrics = MetricsContext()
    # Re-wrap every source so they all report into one metrics context.
    sources = [
        LocalIterator(src.base_iterator, shared_metrics,
                      src.local_transforms, timeout=timeout)
        for src in (self, *others)
    ]

    # Items taken from one source per turn: 1 forces round robin, the
    # larger cap bounds bursts from an always-ready source.
    per_turn = 1 if deterministic else 20

    def build_union(timeout=None):
        while True:
            for source in list(sources):
                try:
                    for _ in range(per_turn):
                        item = next(source)
                        if isinstance(item, _NextValueNotReady):
                            # Source has nothing right now; switch.
                            break
                        yield item
                except StopIteration:
                    sources.remove(source)
            if not sources:
                break

    label = "LocalUnion[{}, {}]".format(self, ", ".join(map(str, others)))
    return LocalIterator(build_union, shared_metrics, [], name=label)
def gather_async(self) -> "LocalIterator[T]":
    """Returns a local iterable for asynchronous iteration.

    One request is kept in flight per shard actor; as soon as a result
    arrives it is yielded and a follow-up request is sent to the same
    actor. Items therefore arrive in non-deterministic order. Actors whose
    shard is exhausted are dropped, and iteration ends when none remain.

    Examples:
        >>> it = from_range(100, 1).gather_async()
        >>> next(it)
        ... 3
        >>> next(it)
        ... 0
        >>> next(it)
        ... 1
    """
    metrics = MetricsContext()

    def base_iterator(timeout=None):
        actors = []
        for actor_set in self.actor_sets:
            actor_set.init_actors()
            actors.extend(actor_set.actors)

        # In-flight par_iter_next() object refs -> the actor serving them.
        in_flight = {actor.par_iter_next.remote(): actor for actor in actors}

        while in_flight:
            refs = list(in_flight)
            if timeout is not None:
                ready, _ = ray.wait(
                    refs, num_returns=len(refs), timeout=timeout)
            else:
                # Collect everything already finished in one batch call...
                ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
                # ...and block for a single result only if nothing was.
                if not ready:
                    ready, _ = ray.wait(refs, num_returns=1)

            for ref in ready:
                actor = in_flight.pop(ref)
                try:
                    metrics.cur_actor = actor
                    yield ray.get(ref)
                    in_flight[actor.par_iter_next.remote()] = actor
                except StopIteration:
                    # Shard exhausted: the actor is not re-queued.
                    pass

            # Always yield after each round of wait with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    return LocalIterator(
        base_iterator, metrics, name="{}.gather_async()".format(self))
def LocalReplay(replay_buffer, train_batch_size,
                prioritized_replay_beta=None):
    """Replay experiences from a local buffer instance.

    Arguments:
        replay_buffer: buffer to replay experiences from.
        train_batch_size (int): batch size of fetches from the buffer.
        prioritized_replay_beta: object exposing ``value(num_steps_trained)``
            whose result is passed as ``beta`` to ``sample_with_idxes``.
            Required when ``replay_buffer`` is a PrioritizedReplayBuffer;
            unused otherwise.

    Returns:
        A LocalIterator that yields MultiAgentBatch objects built from
        ``train_batch_size`` sampled experiences, keyed by DEFAULT_POLICY_ID.

    Raises:
        ValueError: if ``replay_buffer`` is a PrioritizedReplayBuffer and
            ``prioritized_replay_beta`` is not given.
    """
    # Fail at construction instead of with an AttributeError on None the
    # first time the iterator is pulled.
    if (isinstance(replay_buffer, PrioritizedReplayBuffer)
            and prioritized_replay_beta is None):
        raise ValueError(
            "prioritized_replay_beta is required when replaying from a "
            "PrioritizedReplayBuffer")
    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options
    synchronize_sampling = False

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, buffer in replay_buffers.items():
                # With synchronized sampling, every policy reuses the
                # indices drawn for the first one.
                if not synchronize_sampling or idxes is None:
                    idxes = buffer.sample_idxes(train_batch_size)
                if isinstance(buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[
                        STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(
                             num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = buffer.sample_with_idxes(idxes)
                    # Fill unit weights and -1 placeholder indexes so the
                    # batch has the same keys as the prioritized case.
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())
def batch_across_shards(self) -> "LocalIterator[List[T]]":
    """Iterate over the results of multiple shards in parallel.

    Each round requests one item from every still-active shard actor and
    yields the results as a list in actor order. When some shard is
    exhausted, its actor is dropped and the items already returned by the
    other shards in that round are yielded as a shorter list. With a
    timeout set, _NextValueNotReady placeholders are interleaved.

    Examples:
        >>> it = from_iterators([range(3), range(3)])
        >>> next(it.batch_across_shards())
        ... [0, 0]
    """
    def base_iterator(timeout=None):
        active = []
        for actor_set in self.actor_sets:
            actor_set.init_actors()
            active.extend(actor_set.actors)

        def request_round():
            # One outstanding par_iter_next() per active actor, same order.
            return [actor.par_iter_next.remote() for actor in active]

        futures = request_round()
        while active:
            try:
                batch = ray.get(futures, timeout=timeout)
            except TimeoutError:
                # Keep the same futures; they are retried on the next pull.
                yield _NextValueNotReady()
                continue
            except StopIteration:
                # At least one shard is exhausted: keep what the others
                # returned and drop the finished actors.
                survivors = []
                for actor, future in zip(list(active), futures):
                    try:
                        survivors.append(ray.get(future))
                    except StopIteration:
                        active.remove(actor)
                if survivors:
                    yield survivors
                futures = request_round()
                continue
            yield batch
            futures = request_round()
            # Always yield after each round of gets with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    return LocalIterator(
        base_iterator,
        MetricsContext(),
        name="{}.batch_across_shards()".format(self))
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
    """Returns a local iterable for asynchronous iteration.

    Each shard actor keeps up to ``async_queue_depth`` requests in flight.
    Whenever a result arrives it is yielded and a replacement request is
    sent to the same actor, so items arrive in non-deterministic order.
    Actors whose shard is exhausted stop being re-queued, and iteration
    ends once no requests remain.

    Arguments:
        async_queue_depth (int): The max number of async requests in flight
            per actor. Increasing this improves the amount of pipeline
            parallelism in the iterator.

    Raises:
        ValueError: if ``async_queue_depth`` is less than 1.

    Examples:
        >>> it = from_range(100, 1).gather_async()
        >>> next(it)
        ... 3
        >>> next(it)
        ... 0
        >>> next(it)
        ... 1
    """
    if async_queue_depth < 1:
        raise ValueError("queue depth must be positive")

    def base_iterator(timeout=None):
        metrics = LocalIterator.get_metrics()
        actors = []
        for actor_set in self.actor_sets:
            actor_set.init_actors()
            actors.extend(actor_set.actors)

        # Prime the pipeline: issue requests round by round across actors
        # until each has async_queue_depth outstanding.
        in_flight = {
            actor.par_iter_next.remote(): actor
            for _ in range(async_queue_depth)
            for actor in actors
        }

        while in_flight:
            refs = list(in_flight)
            if timeout is not None:
                ready, _ = ray.wait(
                    refs, num_returns=len(refs), timeout=timeout)
            else:
                # Collect everything already finished in one batch call...
                ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
                # ...and block for a single result only if nothing was.
                if not ready:
                    ready, _ = ray.wait(refs, num_returns=1)

            for ref in ready:
                actor = in_flight.pop(ref)
                try:
                    metrics.current_actor = actor
                    yield ray.get(ref)
                    in_flight[actor.par_iter_next.remote()] = actor
                except StopIteration:
                    # Shard exhausted: do not send this actor a replacement.
                    pass

            # Always yield after each round of wait with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    return LocalIterator(
        base_iterator,
        MetricsContext(),
        name="{}.gather_async()".format(self))