class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote workers (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
          replayed. This supports RNNs, but not prioritization.

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization.
    """

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 rollout_fragment_length=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been
                sampled before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers.
            num_replay_buffer_shards (int): number of actors to use to store
                replay samples
            max_weight_sync_delay (int): update the weights of a rollout
                worker after collecting this number of timesteps from it
            debug (bool): return extra debug stats
            batch_replay (bool): replay entire sequential batches of
                experiences instead of sampling steps individually
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Otherwise kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())

    @override(PolicyOptimizer)
    def step(self):
        assert self.learner.is_alive()
        assert len(self.workers.remote_workers()) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        self.workers.reset(remote_workers)
        self.sample_tasks.reset_workers(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
            self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(
                self.timers["sample"].mean_throughput, 3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_workers(self, remote_workers):
        self.workers.reset(remote_workers)
        weights = self.workers.local_worker().get_weights()
        for ev in self.workers.remote_workers():
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            # First try a batched ray.get().
            ray_error = None
            try:
                counts = {
                    i: v
                    for i, v in enumerate(
                        ray_get_and_free([c[1][1] for c in completed]))
                }
            # If there are failed workers, try to recover the still good ones
            # (via non-batched ray.get()) and store the first error (to raise
            # later).
            except RayError:
                counts = {}
                for i, c in enumerate(completed):
                    try:
                        counts[i] = ray_get_and_free(c[1][1])
                    except RayError as e:
                        logger.exception(
                            "Error in completed task: {}".format(e))
                        ray_error = ray_error if ray_error is not None else e

            for i, (ev, (sample_batch, count)) in enumerate(completed):
                # Skip failed tasks.
                if i not in counts:
                    continue

                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed.
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors.
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request.
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

            # Now that all still good tasks have been kicked off again,
            # we can throw the error.
            if ray_error:
                raise ray_error

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray_get_and_free(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
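

# --- Usage sketch (not part of the original module) -------------------------
# A minimal sketch of how the optimizer above is typically driven: wrap an
# existing WorkerSet, call step() repeatedly, then read stats() and stop().
# Construction of the WorkerSet is omitted, and the keyword values below are
# illustrative assumptions, not defaults taken from any particular config.
def run_apex_optimizer_sketch(workers, target_timesteps=1000000):
    optimizer = AsyncReplayOptimizer(
        workers,
        learning_starts=50000,
        buffer_size=2000000,
        train_batch_size=512,
        rollout_fragment_length=50,
        num_replay_buffer_shards=4)
    while optimizer.num_steps_sampled < target_timesteps:
        # Each call processes completed sample/replay tasks and feeds the
        # learner thread; sampling and training proceed asynchronously.
        optimizer.step()
    stats = optimizer.stats()  # throughput and timing breakdown
    optimizer.stop()           # terminate replay actors and learner thread
    return stats
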
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote evaluators (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
          replayed. This supports RNNs, but not prioritization.

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization.
    """

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400,
              debug=False,
              batch_replay=False):
        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Otherwise kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.remote_evaluators:
            self._set_evaluators(self.remote_evaluators)

    @override(PolicyOptimizer)
    def step(self):
        assert self.learner.is_alive()
        assert len(self.remote_evaluators) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def stats(self):
        replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(
                self.timers["sample"].mean_throughput, 3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_evaluators(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            counts = ray.get([c[1][1] for c in completed])
            for i, (ev, (sample_batch, count)) in enumerate(completed):
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray.get(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote evaluators (Ape-X actors), and replay buffer actors.

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization.
    """

    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400,
              debug=False):
        self.debug = debug
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        self.replay_actors = create_colocated(ReplayActor, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "enqueue", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Otherwise kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.remote_evaluators:
            self.set_evaluators(self.remote_evaluators)

    # For https://github.com/ray-project/ray/issues/2541 only
    def set_evaluators(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def step(self):
        assert len(self.remote_evaluators) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            counts = ray.get([c[1][1] for c in completed])
            for i, (ev, (sample_batch, count)) in enumerate(completed):
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                with self.timers["get_samples"]:
                    samples = ray.get(replay)
                with self.timers["enqueue"]:
                    self.learner.inqueue.put((ra, samples))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, replay, td_error, count = self.learner.outqueue.get()
                ra.update_priorities.remote(replay["batch_indexes"], td_error)
                train_timesteps += count

        return sample_timesteps, train_timesteps

    def stats(self):
        replay_stats = ray.get(self.replay_actors[0].stats.remote())
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(
                self.timers["sample"].mean_throughput, 3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
        }
        debug_stats = {
            "replay_shard_0": replay_stats,
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }
        if self.debug:
            stats.update(debug_stats)
        return dict(PolicyOptimizer.stats(self), **stats)
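

# --- Priority-update sketch (the real logic lives in the replay actor's
# prioritized buffer, which is not shown in this file) ------------------------
# In the _step() loops above, the learner thread reports a "td_error" per
# replayed step and the optimizer forwards it via update_priorities.remote().
# Under the standard proportional prioritization scheme the new priority of a
# step is roughly as computed below; alpha and eps correspond to the
# prioritized_replay_alpha / prioritized_replay_eps arguments. The helper
# name is hypothetical and the exact formula may differ from the buffer's.
def _new_priorities_sketch(td_errors, alpha=0.6, eps=1e-6):
    # priority_i = (|td_error_i| + eps) ** alpha
    return [(abs(e) + eps) ** alpha for e in td_errors]
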
class ApexOptimizer(Optimizer):
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400):
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = GenericLearner(self.local_evaluator)
        self.learner.start()

        self.replay_actors = create_colocated(ReplayActor, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)
        assert len(self.remote_evaluators) > 0

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "enqueue", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.meters = {
            k: WindowStat(k, 10)
            for k in [
                "samples_per_loop", "replays_per_loop", "reprios_per_loop",
                "reweights_per_loop"
            ]
        }
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Otherwise kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample.remote())

    def step(self):
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            i = 0
            num_weight_syncs = 0
            for ev, sample_batch in self.sample_tasks.completed():
                i += 1
                sample_timesteps += self.sample_batch_size

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += self.sample_batch_size
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    if weights is None:
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample.remote())
            self.meters["samples_per_loop"].push(i)
            self.meters["reweights_per_loop"].push(num_weight_syncs)

        with self.timers["replay_processing"]:
            i = 0
            for ra, replay in self.replay_tasks.completed():
                i += 1
                self.replay_tasks.add(ra, ra.replay.remote())
                with self.timers["get_samples"]:
                    samples = ray.get(replay)
                with self.timers["enqueue"]:
                    self.learner.inqueue.put((ra, samples))
            self.meters["replays_per_loop"].push(i)

        with self.timers["update_priorities"]:
            i = 0
            while not self.learner.outqueue.empty():
                i += 1
                ra, replay, td_error = self.learner.outqueue.get()
                ra.update_priorities.remote(replay, td_error)
                train_timesteps += self.train_batch_size
            self.meters["reprios_per_loop"].push(i)

        return sample_timesteps, train_timesteps

    def stats(self):
        replay_stats = ray.get(self.replay_actors[0].stats.remote())
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "replay_shard_0": replay_stats,
            "timing_breakdown": timing,
            "sample_throughput": round(
                self.timers["sample"].mean_throughput, 3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "samples": self.meters["samples_per_loop"].stats(),
            "replays": self.meters["replays_per_loop"].stats(),
            "reprios": self.meters["reprios_per_loop"].stats(),
            "reweights": self.meters["reweights_per_loop"].stats(),
        }
        return dict(Optimizer.stats(self), **stats)