# This excerpt assumes the usual module preamble of the surrounding RLlib
# optimizer files: the standard-library and Ray imports below, plus the
# RLlib-internal helpers (PolicyOptimizer, LearnerThread, ReplayActor,
# BatchReplayActor, TFMultiGPULearner, Aggregator, TaskPool, create_colocated,
# TimerStat, override, REPLAY_QUEUE_DEPTH, SAMPLE_QUEUE_DEPTH) imported from
# their respective rllib modules.
import logging
import random
import time

import numpy as np

import ray

logger = logging.getLogger(__name__)


class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote evaluators (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
          replayed. This supports RNNs, but not prioritization.

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization.
    """

    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 sample_batch_size=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.remote_evaluators:
            self._set_evaluators(self.remote_evaluators)

    @override(PolicyOptimizer)
    def step(self):
        assert self.learner.is_alive()
        assert len(self.remote_evaluators) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        self.sample_tasks.reset_evaluators(remote_evaluators)

    @override(PolicyOptimizer)
    def stats(self):
        replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(
                self.timers["sample"].mean_throughput, 3),
            "train_throughput": round(
                self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_evaluators(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            counts = ray.get([c[1][1] for c in completed])
            for i, (ev, (sample_batch, count)) in enumerate(completed):
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray.get(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
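
# Illustrative, standalone sketch (not part of RLlib): the docstring above says
# the "td_error" returned by the evaluators is used for sample prioritization.
# This hypothetical helper shows the standard prioritized-replay math that role
# implies, reusing the alpha/beta/eps roles of the constructor arguments above.
def _prioritization_sketch(td_errors, alpha=0.6, beta=0.4, eps=1e-6):
    # Priorities grow with the magnitude of the TD error; eps keeps
    # zero-error samples from becoming unsampleable.
    priorities = (np.abs(td_errors) + eps) ** alpha
    probs = priorities / priorities.sum()

    # Importance-sampling weights correct for the non-uniform sampling and
    # are normalized by the max weight so they stay in (0, 1].
    n = len(td_errors)
    weights = (n * probs) ** (-beta)
    weights /= weights.max()
    return probs, weights

# Example: _prioritization_sketch(np.array([0.1, 2.0, 0.5])) gives the largest
# sampling probability (and the smallest correction weight) to the 2.0 error.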

class AggregationWorkerBase(object):
    """Aggregators should extend from this class."""

    def __init__(self, initial_weights_obj_id, remote_evaluators,
                 max_sample_requests_in_flight_per_worker, replay_proportion,
                 replay_buffer_num_slots, train_batch_size,
                 sample_batch_size):
        self.broadcasted_weights = initial_weights_obj_id
        self.remote_evaluators = remote_evaluators
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size

        if replay_proportion:
            if (replay_buffer_num_slots * sample_batch_size <=
                    train_batch_size):
                raise ValueError(
                    "Replay buffer size is too small to produce a train "
                    "batch, please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(self.broadcasted_weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []

        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999):
        """Iterate over train batches.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning too
                much data at once.
        """

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch(
                    blocking_wait=True, max_yield=max_yield)):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)
            if (sum(b.count for b in self.batch_buffer) >=
                    self.train_batch_size):
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    self.replay_batches.pop(0)

            ev.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

    @override(Aggregator)
    def stats(self):
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
        }

    @override(Aggregator)
    def reset(self, remote_evaluators):
        self.sample_tasks.reset_evaluators(remote_evaluators)

    def _augment_with_replay(self, sample_futures):
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
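
# Illustrative, standalone sketch (not part of RLlib): the replay loop in
# _augment_with_replay above ("f = replay_proportion; while random.random() <
# f: f -= 1") emits a random integer number of replayed batches whose
# expectation equals replay_proportion, even for proportions greater than 1.
# This hypothetical helper estimates that expectation empirically.
def _expected_replays_sketch(replay_proportion, num_trials=100000):
    total = 0
    for _ in range(num_trials):
        f = replay_proportion
        # Same loop shape as the production code, but counting instead of
        # yielding replayed batches.
        while random.random() < f:
            f -= 1
            total += 1
    return total / num_trials

# Example: _expected_replays_sketch(2.5) returns roughly 2.5, i.e. on average
# 2.5 replayed batches are interleaved per fresh sample batch.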

class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the IMPALA architecture.

    This class coordinates the data transfers between the learner thread
    and remote evaluators (IMPALA actors).
    """

    @override(PolicyOptimizer)
    def _init(self,
              train_batch_size=500,
              sample_batch_size=50,
              num_envs_per_worker=1,
              num_gpus=0,
              lr=0.0005,
              replay_buffer_num_slots=0,
              replay_proportion=0.0,
              num_data_loader_buffers=1,
              max_sample_requests_in_flight_per_worker=2,
              broadcast_interval=1,
              num_sgd_iter=1,
              minibatch_buffer_size=1,
              learner_queue_size=16,
              _fake_gpus=False):
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.broadcast_interval = broadcast_interval

        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_sum = {}

        if num_gpus > 1 or num_data_loader_buffers > 1:
            logger.info(
                "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
                    num_gpus, num_data_loader_buffers))
            if num_data_loader_buffers < minibatch_buffer_size:
                raise ValueError(
                    "In multi-gpu mode you must have at least as many "
                    "parallel data loader buffers as minibatch buffers: "
                    "{} vs {}".format(num_data_loader_buffers,
                                      minibatch_buffer_size))
            self.learner = TFMultiGPULearner(
                self.local_evaluator,
                lr=lr,
                num_gpus=num_gpus,
                train_batch_size=train_batch_size,
                num_data_loader_buffers=num_data_loader_buffers,
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                learner_queue_size=learner_queue_size,
                _fake_gpus=_fake_gpus)
        else:
            self.learner = LearnerThread(self.local_evaluator,
                                         minibatch_buffer_size, num_sgd_iter,
                                         learner_queue_size)
        self.learner.start()

        # Stats
        self._optimizer_step_timer = TimerStat()
        self.num_weight_syncs = 0
        self.num_replayed = 0
        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_val = {}

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        if replay_proportion:
            if (replay_buffer_num_slots * sample_batch_size <=
                    train_batch_size):
                raise ValueError(
                    "Replay buffer size is too small to produce a train "
                    "batch, please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []

    def add_stat_val(self, key, val):
        if key not in self._last_stats_sum:
            self._last_stats_sum[key] = 0
            self._last_stats_time[key] = self._stats_start_time
        self._last_stats_sum[key] += val

    def get_mean_stats_and_reset(self):
        now = time.time()
        mean_stats = {
            key: round(val / (now - self._last_stats_time[key]), 3)
            for key, val in self._last_stats_sum.items()
        }

        for key in self._last_stats_sum.keys():
            self._last_stats_sum[key] = 0
            self._last_stats_time[key] = time.time()

        return mean_stats

    @override(PolicyOptimizer)
    def step(self):
        if len(self.remote_evaluators) == 0:
            raise ValueError("Config num_workers=0 means training will hang!")
        assert self.learner.is_alive()
        with self._optimizer_step_timer:
            sample_timesteps, train_timesteps = self._step()

        if sample_timesteps > 0:
            self.add_stat_val("sample_throughput", sample_timesteps)
        if train_timesteps > 0:
            self.add_stat_val("train_throughput", train_timesteps)

        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        self.sample_tasks.reset_evaluators(remote_evaluators)

    @override(PolicyOptimizer)
    def stats(self):
        def timer_to_ms(timer):
            return round(1000 * timer.mean, 3)

        timing = {
            "optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer),
            "learner_grad_time_ms": timer_to_ms(self.learner.grad_timer),
            "learner_load_time_ms": timer_to_ms(self.learner.load_timer),
            "learner_load_wait_time_ms": timer_to_ms(
                self.learner.load_wait_timer),
            "learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer),
        }
        stats = dict({
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "timing_breakdown": timing,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }, **self.get_mean_stats_and_reset())
        self._last_stats_val.clear()
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        num_sent = 0
        weights = None

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch()):
            self.batch_buffer.append(sample_batch)
            if (sum(b.count for b in self.batch_buffer) >=
                    self.train_batch_size):
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                self.learner.inqueue.put(train_batch)
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            sample_timesteps += sample_batch.count

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    self.replay_batches.pop(0)

            # Note that it's important to pull new weights once
            # updated to avoid excessive correlation between actors
            if weights is None or (self.learner.weights_updated
                                   and num_sent >= self.broadcast_interval):
                self.learner.weights_updated = False
                weights = ray.put(self.local_evaluator.get_weights())
                num_sent = 0
            ev.set_weights.remote(weights)
            self.num_weight_syncs += 1
            num_sent += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

        while not self.learner.outqueue.empty():
            count = self.learner.outqueue.get()
            train_timesteps += count

        return sample_timesteps, train_timesteps

    def _augment_with_replay(self, sample_futures):
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
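
# Illustrative, standalone sketch (not part of RLlib): both _step() above and
# AggregationWorkerBase.iter_train_batches() accumulate small sample batches
# until they cover train_batch_size timesteps and then concatenate them into
# one train batch. This hypothetical helper shows the same buffering pattern
# with plain Python lists standing in for SampleBatch objects.
def _batch_buffering_sketch(sample_batches, train_batch_size):
    batch_buffer = []
    train_batches = []
    for batch in sample_batches:
        batch_buffer.append(batch)
        if sum(len(b) for b in batch_buffer) >= train_batch_size:
            # Stand-in for SampleBatch.concat_samples(): flatten the buffered
            # batches into a single train batch, then reset the buffer.
            train_batches.append([row for b in batch_buffer for row in b])
            batch_buffer = []
    return train_batches

# Example: with 20 sample batches of 50 rows and train_batch_size=500, every
# 10th batch completes a train batch, so two 500-row train batches come out:
# _batch_buffering_sketch([[0] * 50 for _ in range(20)], 500)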