def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(
        self, rayWaitMock):
    """Check that reset_workers drops queued fetches of a failed worker.

    Ten tasks are added to the pool. Six are reported as complete and two
    of those are yielded, which leaves the others waiting in the pool's
    fetching queue. Worker 4 is then omitted from the list handed to
    reset_workers, so the final prefetch is expected to yield only the
    queued results of the surviving workers: 2, 3 and 5.

    Arguments:
        rayWaitMock: mock whose return_value supplies the
            (completed, pending) lists seen by the pool.
    """
    failed_index = 4
    pool = TaskPool()

    # Every created task is kept so that one specific worker can be left
    # out of the reset call further down.
    tasks = []
    for index in range(10):
        created = createMockWorkerAndObjectRef(index)
        pool.add(*created)
        tasks.append(created)

    # Report only part of the work as finished and yield a couple of
    # results, so that the remaining finished ones sit in the fetching
    # queue. The yielded values themselves are not asserted on.
    rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
    fetched = [item[1] for item in pool.completed_prefetch(max_yield=2)]

    # Tasks are still pending, so the completion state has to be updated
    # to no longer report the already completed ones.
    rayWaitMock.return_value = ([], [6, 7, 8, 9])

    # OH NO! WORKER 4 HAS CRASHED! Only the other workers are kept.
    surviving_workers = [
        task[0] for index, task in enumerate(tasks) if index != failed_index
    ]
    pool.reset_workers(surviving_workers)

    # Whatever is left should already be waiting in the _fetching queue.
    fetched = [item[1] for item in pool.completed_prefetch()]
    self.assertListEqual(fetched, [2, 3, 5])
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    Coordinates the data flow between three parties: the remote rollout
    workers (Ape-X actors) that produce samples, the replay buffer actors
    that store them, and the local learner thread that trains on replayed
    batches.

    Two replay modes are available:
        - normal replay: independent samples are replayed.
        - batch replay: simplified mode in which entire sample batches are
          replayed. This supports RNNs, but not prioritization.

    Rollout workers are required to return an additional "td_error" array
    in the info return of compute_gradients(). That error term is used for
    sample prioritization.
    """

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 rollout_fragment_length=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Besides storing the settings, this starts the learner thread,
        creates the replay actors, and queues the initial replay and
        sample requests.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been
                sampled before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized
                replay (not referenced by this constructor)
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers (not referenced by this constructor)
            num_replay_buffer_shards (int): number of actors to use to
                store replay samples
            max_weight_sync_delay (int): update the weights of a rollout
                worker after collecting this number of timesteps from it
            debug (bool): return extra debug stats
            batch_replay (bool): replay entire sequential batches of
                experiences instead of sampling steps individually
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        replay_cls = BatchReplayActor if self.batch_replay else ReplayActor
        replay_actor_args = [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ]
        self.replay_actors = create_colocated(
            replay_cls, replay_actor_args, num_replay_buffer_shards)

        # Stats
        timer_names = (
            "put_weights",
            "get_samples",
            "sample_processing",
            "replay_processing",
            "update_priorities",
            "train",
            "sample",
        )
        self.timers = {name: TimerStat() for name in timer_names}
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for actor in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(actor, actor.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())

    @override(PolicyOptimizer)
    def step(self):
        """Run one iteration of the event loop and record its timing.

        The elapsed time of the iteration is pushed to the "sample" timer
        and, once any training has happened, to the "train" timer as well.
        The sampled / trained step counters are advanced accordingly.
        """
        assert self.learner.is_alive()
        assert len(self.workers.remote_workers()) > 0

        started_at = time.time()
        sample_timesteps, train_timesteps = self._step()
        elapsed = time.time() - started_at

        sample_timer = self.timers["sample"]
        sample_timer.push(elapsed)
        sample_timer.push_units_processed(sample_timesteps)

        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            train_timer = self.timers["train"]
            train_timer.push(elapsed)
            train_timer.push_units_processed(train_timesteps)

        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Request termination of all replay actors and flag the learner."""
        for actor in self.replay_actors:
            actor.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        """Swap in a new list of remote workers.

        Arguments:
            remote_workers (list): workers to keep using; passed both to
                the worker set and to the sample task pool.
        """
        self.workers.reset(remote_workers)
        self.sample_tasks.reset_workers(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        """Return optimizer statistics merged over the base class stats.

        Timing and pending-task details are included only when debug is
        enabled; learner stats only when the learner has any.
        """
        replay_stats = ray_get_and_free(
            self.replay_actors[0].stats.remote(self.debug))

        timing = {}
        for name in self.timers:
            timing["{}_time_ms".format(name)] = round(
                1000 * self.timers[name].mean, 3)
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)

        sample_throughput = round(self.timers["sample"].mean_throughput, 3)
        train_throughput = round(self.timers["train"].mean_throughput, 3)
        result = {
            "sample_throughput": sample_throughput,
            "train_throughput": train_throughput,
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            result.update(debug_stats)
        if self.learner.stats:
            result["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **result)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_workers(self, remote_workers):
        """Install remote workers, sync their weights, and start sampling.

        Each worker receives the local worker's weights, gets its step
        counter zeroed, and has SAMPLE_QUEUE_DEPTH sample requests queued.
        """
        self.workers.reset(remote_workers)
        weights = self.workers.local_worker().get_weights()
        for worker in self.workers.remote_workers():
            worker.set_weights.remote(weights)
            self.steps_since_update[worker] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(
                    worker, worker.sample_with_count.remote())

    def _step(self):
        """Process finished sample and replay tasks once.

        Returns:
            tuple: (sample_timesteps, train_timesteps) handled in this
                call.

        Raises:
            RayError: the first error hit while fetching sample counts,
                raised only after all healthy tasks were processed and
                re-queued.
        """
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            finished = list(self.sample_tasks.completed())
            count_refs = [task[1][1] for task in finished]

            # First try a batched fetch of all counts.
            ray_error = None
            try:
                counts = dict(enumerate(ray_get_and_free(count_refs)))
            except RayError:
                # Some worker failed: fetch one by one to recover the
                # still good results and remember the first error so it
                # can be raised later.
                counts = {}
                for idx, ref in enumerate(count_refs):
                    try:
                        counts[idx] = ray_get_and_free(ref)
                    except RayError as e:
                        logger.exception(
                            "Error in completed task: {}".format(e))
                        if ray_error is None:
                            ray_error = e

            for idx, (worker, (sample_batch, _)) in enumerate(finished):
                # Skip failed tasks.
                if idx not in counts:
                    continue

                num_steps = counts[idx]
                sample_timesteps += num_steps

                # Send the data to the replay buffer
                replay_actor = random.choice(self.replay_actors)
                replay_actor.add_batch.remote(sample_batch)

                # Update weights if needed.
                self.steps_since_update[worker] += num_steps
                if (self.steps_since_update[worker] >=
                        self.max_weight_sync_delay):
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between
                    # actors.
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    worker.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[worker] = 0

                # Kick off another sample request.
                self.sample_tasks.add(
                    worker, worker.sample_with_count.remote())

            # Now that all still good tasks have been kicked off again,
            # the stored error can be thrown.
            if ray_error:
                raise ray_error

        with self.timers["replay_processing"]:
            for replay_actor, replay in self.replay_tasks.completed():
                self.replay_tasks.add(
                    replay_actor, replay_actor.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                    continue
                with self.timers["get_samples"]:
                    samples = ray_get_and_free(replay)
                # Defensive copy against plasma crashes, see #2610 #3452
                payload = samples.copy() if samples else samples
                self.learner.inqueue.put((replay_actor, payload))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                replay_actor, prio_dict, count = self.learner.outqueue.get()
                replay_actor.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote workers (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
            replayed. This supports RNNs, but not prioritization.

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization.
    """

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 sample_batch_size=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Starts the learner thread, creates the replay actors, and queues
        the initial replay and sample requests.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): stored as replay_starts and forwarded
                to the replay actors.
            buffer_size (int): forwarded to the replay actors.
            prioritized_replay (bool): not referenced by this constructor.
            prioritized_replay_alpha (float): replay alpha hyperparameter,
                forwarded to the replay actors.
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): forwarded to the replay actors.
            sample_batch_size (int): not referenced by this constructor;
                presumably the size of batches sampled from workers.
            num_replay_buffer_shards (int): number of replay actors to
                create.
            max_weight_sync_delay (int): a worker gets fresh weights once
                this many timesteps were collected from it since its last
                weight update.
            debug (bool): include extra debug stats in stats().
            batch_replay (bool): use BatchReplayActor instead of
                ReplayActor.
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())

    @override(PolicyOptimizer)
    def step(self):
        """Run one iteration of the event loop and record its timing."""
        assert self.learner.is_alive()
        assert len(self.workers.remote_workers()) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        # The train timer only starts accumulating once the first trained
        # timesteps have been observed.
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Request termination of all replay actors and flag the learner."""
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        """Pass a new list of remote workers to the worker set and pool."""
        self.workers.reset(remote_workers)
        self.sample_tasks.reset_workers(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        """Return optimizer statistics merged over the base class stats.

        Timing and pending-task details are included only when debug is
        enabled; learner stats only when the learner has any.
        """
        replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
            self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput,
                                      3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_workers(self, remote_workers):
        """Install remote workers, sync their weights, and start sampling."""
        self.workers.reset(remote_workers)
        weights = self.workers.local_worker().get_weights()
        for ev in self.workers.remote_workers():
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _step(self):
        """Process finished sample and replay tasks once.

        A failure while fetching one worker's sample count no longer
        aborts the whole iteration: the counts are re-fetched one by one,
        failed tasks are skipped (and not re-queued), healthy tasks are
        processed and re-queued as usual, and only then is the first
        error raised.

        Returns:
            tuple: (sample_timesteps, train_timesteps) handled in this
                call.

        Raises:
            RayError: the first error hit while fetching sample counts.
        """
        # Function-scope imports keep this block self-contained; `ray` is
        # already a dependency of this module (see ray.put below).
        import logging

        from ray.exceptions import RayError

        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            count_refs = [c[1][1] for c in completed]

            # First try a batched fetch of all counts.
            first_error = None
            try:
                counts = dict(enumerate(ray_get_and_free(count_refs)))
            except RayError:
                # At least one task failed. The pool has already handed
                # out every completed task, so bailing out here would
                # drop the healthy workers' batches and never re-issue
                # their sample requests. Recover what can be recovered.
                counts = {}
                for i, ref in enumerate(count_refs):
                    try:
                        counts[i] = ray_get_and_free(ref)
                    except RayError as e:
                        logging.getLogger(__name__).exception(
                            "Error in completed task: %s", e)
                        if first_error is None:
                            first_error = e

            for i, (ev, (sample_batch, _)) in enumerate(completed):
                # Skip failed tasks.
                if i not in counts:
                    continue
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

            # All healthy tasks have been re-queued; surface the failure.
            if first_error is not None:
                raise first_error

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray_get_and_free(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
class AggregationWorkerBase:
    """Aggregators should extend from this class."""

    def __init__(self, initial_weights_obj_id, remote_workers,
                 max_sample_requests_in_flight_per_worker, replay_proportion,
                 replay_buffer_num_slots, train_batch_size,
                 rollout_fragment_length):
        """Initialize an aggregator.

        Arguments:
            initial_weights_obj_id (ObjectID): initial worker weights
            remote_workers (list): set of remote workers assigned to this
                agg
            max_sample_requests_in_flight_per_worker (int): max queue size
                per worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers.

        Raises:
            ValueError: if replay is requested but the replay buffer
                cannot hold more than one train batch worth of steps.
        """
        self.broadcasted_weights = initial_weights_obj_id
        self.remote_workers = remote_workers
        self.rollout_fragment_length = rollout_fragment_length
        self.train_batch_size = train_batch_size

        buffer_capacity = replay_buffer_num_slots * rollout_fragment_length
        if replay_proportion and buffer_capacity <= train_batch_size:
            raise ValueError(
                "Replay buffer size is too small to produce train, "
                "please increase replay_buffer_num_slots.",
                replay_buffer_num_slots, rollout_fragment_length,
                train_batch_size)

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        for worker in self.remote_workers:
            worker.set_weights.remote(self.broadcasted_weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(worker, worker.sample.remote())

        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []
        self.replay_index = 0

        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999):
        """Iterate over train batches.

        Sample batches are buffered until they add up to at least
        train_batch_size steps, then yielded as one train batch. Every
        freshly sampled (non-replayed) batch also triggers a weight push
        to, and a new sample request for, the worker it came from.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning
                too much data at once.
        """
        completed = self.sample_tasks.completed_prefetch(
            blocking_wait=True, max_yield=max_yield)
        for worker, sample_batch in self._augment_with_replay(completed):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)

            buffered_steps = sum(b.count for b in self.batch_buffer)
            if buffered_steps >= self.train_batch_size:
                first = self.batch_buffer[0]
                if len(self.batch_buffer) == 1:
                    # make a defensive copy to avoid sharing plasma memory
                    # across multiple threads
                    train_batch = first.copy()
                else:
                    train_batch = first.concat_samples(self.batch_buffer)
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if worker is None:
                continue

            # Put in replay buffer if enabled
            num_slots = self.replay_buffer_num_slots
            if num_slots > 0:
                if len(self.replay_batches) < num_slots:
                    self.replay_batches.append(sample_batch)
                else:
                    # Buffer is full: overwrite slots in round-robin order.
                    self.replay_batches[self.replay_index] = sample_batch
                    self.replay_index = (self.replay_index + 1) % num_slots

            worker.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(worker, worker.sample.remote())

    @override(Aggregator)
    def stats(self):
        """Return the weight sync and replayed step counters."""
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        """Pass a new list of remote workers to the sample task pool."""
        self.sample_tasks.reset_workers(remote_workers)

    def _augment_with_replay(self, sample_futures):
        """Resolve sample futures and interleave replayed batches.

        Yields (worker, batch) for every resolved future. Once the replay
        buffer holds enough batches, each fresh batch may be followed by
        (None, replay_batch) pairs; how many is drawn at random from
        replay_proportion.
        """
        for worker, future in sample_futures:
            sample_batch = ray.get(future)
            yield worker, sample_batch

            # Replay only once more batches are stored than are needed to
            # fill one train batch.
            num_needed = int(
                np.ceil(self.train_batch_size / self.rollout_fragment_length))
            if len(self.replay_batches) <= num_needed:
                continue

            remaining = self.replay_proportion
            while random.random() < remaining:
                remaining -= 1
                replay_batch = random.choice(self.replay_batches)
                self.num_replayed += replay_batch.count
                yield None, replay_batch
class DRAggregatorBase:
    """Aggregators should extend from this class."""

    def __init__(self,
                 initial_weights_obj_id,
                 remote_workers,
                 max_sample_requests_in_flight_per_worker,
                 replay_proportion,
                 replay_buffer_num_slots,
                 train_batch_size,
                 sample_batch_size,
                 sync_sampling=False):
        """Initialize an aggregator.

        No sampling is started here; call start() before iterating.

        Arguments:
            initial_weights_obj_id (ObjectID): initial worker weights
            remote_workers (list): set of remote workers assigned to this
                agg
            max_sample_requests_in_flight_per_worker (int): max queue size
                per worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            sync_sampling (bool): if set, iter_train_batches stops
                re-queueing sample requests after it has yielded a train
                batch and then drains the task pool.

        Raises:
            ValueError: if replay is requested but the replay buffer
                cannot hold more than one train batch worth of steps.
        """
        self.broadcasted_weights = initial_weights_obj_id
        self.remote_workers = remote_workers
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size

        buffer_capacity = replay_buffer_num_slots * sample_batch_size
        if replay_proportion and buffer_capacity <= train_batch_size:
            raise ValueError(
                "Replay buffer size is too small to produce train, "
                "please increase replay_buffer_num_slots.",
                replay_buffer_num_slots, sample_batch_size, train_batch_size)

        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.max_sample_requests_in_flight_per_worker = (
            max_sample_requests_in_flight_per_worker)
        self.started = False
        self.sample_tasks = TaskPool()
        self.replay_batches = []
        self.replay_index = 0

        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0
        self.sample_timesteps = 0
        self.sync_sampling = sync_sampling

    def start(self):
        """Push the broadcast weights to all workers and start sampling."""
        # Kick off async background sampling
        for worker in self.remote_workers:
            worker.set_weights.remote(self.broadcasted_weights)
            for _ in range(self.max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(worker, worker.sample.remote())
        self.started = True

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999, _recursive_called=False):
        """Iterate over train batches.

        start() must have been called first.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning
                too much data at once.
            _recursive_called (bool): internal flag set by the draining
                loop at the end of this method so that the nested call
                does not start draining again itself.
        """
        assert self.started
        already_sent_out = False

        completed = self.sample_tasks.completed_prefetch(
            blocking_wait=True, max_yield=max_yield)
        # `worker` is the rollout worker, or None for a replayed batch.
        for worker, sample_batch in self._augment_with_replay(completed):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)

            buffered_steps = sum(b.count for b in self.batch_buffer)
            if buffered_steps >= self.train_batch_size:
                first = self.batch_buffer[0]
                if len(self.batch_buffer) == 1:
                    # make a defensive copy to avoid sharing plasma memory
                    # across multiple threads
                    train_batch = first.copy()
                else:
                    train_batch = first.concat_samples(self.batch_buffer)
                self.sample_timesteps += train_batch.count

                if self.sync_sampling:
                    # With sync sampling, a batch has now gone out: from
                    # here on no new sample requests are queued in this
                    # call. start() has to be called from the outside to
                    # resume sampling.
                    already_sent_out = True
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if worker is None:
                continue

            # Put in replay buffer if enabled
            num_slots = self.replay_buffer_num_slots
            if num_slots > 0:
                if len(self.replay_batches) < num_slots:
                    self.replay_batches.append(sample_batch)
                else:
                    # Buffer is full: overwrite slots in round-robin order.
                    self.replay_batches[self.replay_index] = sample_batch
                    self.replay_index = (self.replay_index + 1) % num_slots

            if already_sent_out and self.sync_sampling:
                continue

            worker.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(worker, worker.sample.remote())

        if self.sync_sampling and not _recursive_called:
            # A tricky way to force exhaust the task pool.
            # NOTE(review): each nested call starts with
            # already_sent_out = False and so may queue new requests
            # until it yields a batch itself -- confirm this is intended.
            while self.sample_tasks.count > 0:
                for train_batch in self.iter_train_batches(
                        max_yield, _recursive_called=True):
                    yield train_batch

    @override(Aggregator)
    def stats(self):
        """Return the weight sync, replayed step and sampled step counters."""
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "sample_timesteps": self.sample_timesteps
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        """Pass a new list of remote workers to the sample task pool."""
        self.sample_tasks.reset_workers(remote_workers)

    def _augment_with_replay(self, sample_futures):
        """Resolve sample futures and interleave replayed batches.

        Yields (worker, batch) for every resolved future. Once the replay
        buffer holds enough batches, each fresh batch may be followed by
        (None, replay_batch) pairs; how many is drawn at random from
        replay_proportion.
        """
        for worker, future in sample_futures:
            sample_batch = ray_get_and_free(future)
            yield worker, sample_batch

            # Replay only once more batches are stored than are needed to
            # fill one train batch.
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            if len(self.replay_batches) <= num_needed:
                continue

            remaining = self.replay_proportion
            while random.random() < remaining:
                remaining -= 1
                replay_batch = random.choice(self.replay_batches)
                self.num_replayed += replay_batch.count
                yield None, replay_batch