def _init(self,
          learning_starts=1000,
          buffer_size=10000,
          prioritized_replay=True,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta=0.4,
          prioritized_replay_eps=1e-6,
          train_batch_size=512,
          sample_batch_size=50,
          num_replay_buffer_shards=1,
          max_weight_sync_delay=400,
          clip_rewards=True,
          debug=False):
    self.debug = debug
    self.replay_starts = learning_starts
    self.prioritized_replay_beta = prioritized_replay_beta
    self.prioritized_replay_eps = prioritized_replay_eps
    self.train_batch_size = train_batch_size
    self.sample_batch_size = sample_batch_size
    self.max_weight_sync_delay = max_weight_sync_delay

    self.learner = LearnerThread(self.local_evaluator)
    self.learner.start()

    self.replay_actors = create_colocated(ReplayActor, [
        num_replay_buffer_shards, learning_starts, buffer_size,
        train_batch_size, prioritized_replay_alpha, prioritized_replay_beta,
        prioritized_replay_eps, clip_rewards
    ], num_replay_buffer_shards)
    assert len(self.remote_evaluators) > 0

    # Stats
    self.timers = {
        k: TimerStat()
        for k in [
            "put_weights", "get_samples", "enqueue", "sample_processing",
            "replay_processing", "update_priorities", "train", "sample"
        ]
    }
    self.num_weight_syncs = 0
    self.learning_started = False

    # Number of worker steps since the last weight update
    self.steps_since_update = {}

    # Kick off replay tasks for local gradient updates
    self.replay_tasks = TaskPool()
    for ra in self.replay_actors:
        for _ in range(REPLAY_QUEUE_DEPTH):
            self.replay_tasks.add(ra, ra.replay.remote())

    # Kick off async background sampling
    self.sample_tasks = TaskPool()
    weights = self.local_evaluator.get_weights()
    for ev in self.remote_evaluators:
        ev.set_weights.remote(weights)
        self.steps_since_update[ev] = 0
        for _ in range(SAMPLE_QUEUE_DEPTH):
            self.sample_tasks.add(ev, ev.sample.remote())
def test_completed_prefetch_yieldsAllComplete(self, rayWaitMock):
    task1 = createMockWorkerAndObjectRef(1)
    task2 = createMockWorkerAndObjectRef(2)

    # Return the second task as complete and the first as pending
    rayWaitMock.return_value = ([2], [1])

    pool = TaskPool()
    pool.add(*task1)
    pool.add(*task2)

    fetched = list(pool.completed_prefetch())
    self.assertListEqual(fetched, [task2])
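
# A minimal sketch of the scaffolding the TaskPool tests in this file assume:
# the createMockWorkerAndObjectRef helper and the patched ray.wait passed in
# as rayWaitMock. The import path for TaskPool and the unittest wiring are
# assumptions for illustration and may differ from the actual test module.
import unittest
from unittest.mock import MagicMock, patch

from ray.rllib.utils.actors import TaskPool  # assumed location


def createMockWorkerAndObjectRef(obj_ref):
    # TaskPool only uses the worker for identity bookkeeping, so a bare
    # MagicMock paired with an integer "object ref" is enough for these tests.
    return (MagicMock(), obj_ref)


@patch("ray.wait")
class TaskPoolTest(unittest.TestCase):
    # The test methods shown in this file would live here; each receives the
    # patched ray.wait as its rayWaitMock argument.
    pass


if __name__ == "__main__":
    unittest.main()
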
def __init__(self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size, sync_sampling=False): """Initialize an aggregator. Arguments: initial_weights_obj_id (ObjectID): initial worker weights remote_workers (list): set of remote workers assigned to this agg max_sample_request_in_flight_per_worker (int): max queue size per worker replay_proportion (float): ratio of replay to sampled outputs replay_buffer_num_slots (int): max number of sample batches to store in the replay buffer train_batch_size (int): size of batches to learn on sample_batch_size (int): size of batches to sample from workers """ self.broadcasted_weights = initial_weights_obj_id self.remote_workers = remote_workers self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) self.batch_buffer = [] self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.max_sample_requests_in_flight_per_worker = \ max_sample_requests_in_flight_per_worker self.started = False self.sample_tasks = TaskPool() self.replay_batches = [] self.replay_index = 0 self.num_sent_since_broadcast = 0 self.num_weight_syncs = 0 self.num_replayed = 0 self.sample_timesteps = 0 self.sync_sampling = sync_sampling
def __init__(self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, rollout_fragment_length): """Initialize an aggregator. Arguments: initial_weights_obj_id (ObjectID): initial worker weights remote_workers (list): set of remote workers assigned to this agg max_sample_request_in_flight_per_worker (int): max queue size per worker replay_proportion (float): ratio of replay to sampled outputs replay_buffer_num_slots (int): max number of sample batches to store in the replay buffer train_batch_size (int): size of batches to learn on rollout_fragment_length (int): size of batches to sample from workers. """ self.broadcasted_weights = initial_weights_obj_id self.remote_workers = remote_workers self.rollout_fragment_length = rollout_fragment_length self.train_batch_size = train_batch_size if replay_proportion: if (replay_buffer_num_slots * rollout_fragment_length <= train_batch_size): raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, rollout_fragment_length, train_batch_size) # Kick off async background sampling self.sample_tasks = TaskPool() for ev in self.remote_workers: ev.set_weights.remote(self.broadcasted_weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] self.replay_index = 0 self.num_sent_since_broadcast = 0 self.num_weight_syncs = 0 self.num_replayed = 0
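
# Quick arithmetic behind the ValueError raised above: the replay buffer must
# be able to hold strictly more timesteps than one train batch, i.e.
# replay_buffer_num_slots * rollout_fragment_length > train_batch_size.
# The concrete numbers below reuse defaults that appear elsewhere in this
# file and are illustrative only.
rollout_fragment_length = 50
train_batch_size = 500
min_slots = train_batch_size // rollout_fragment_length + 1
print(min_slots)  # 11 -- the smallest slot count that passes the check
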
def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(
        self, rayWaitMock):
    pool = TaskPool()

    # We need to hold onto the tasks for this test so that we can fail a
    # specific worker
    tasks = []
    for i in range(10):
        task = createMockWorkerAndObjectRef(i)
        pool.add(*task)
        tasks.append(task)

    # Simulate only some of the work being complete and fetch a couple of
    # tasks in order to fill the fetching queue
    rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
    fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]

    # As we still have some pending tasks, we need to update the
    # completion states to remove the completed tasks
    rayWaitMock.return_value = ([], [6, 7, 8, 9])

    pool.reset_workers([
        tasks[0][0],
        tasks[1][0],
        tasks[2][0],
        tasks[3][0],
        # OH NO! WORKER 4 HAS CRASHED!
        tasks[5][0],
        tasks[6][0],
        tasks[7][0],
        tasks[8][0],
        tasks[9][0],
    ])

    # Fetch the remaining tasks, which should already be in the _fetching
    # queue
    fetched = [pair[1] for pair in pool.completed_prefetch()]
    self.assertListEqual(fetched, [2, 3, 5])
def test_completed_prefetch_yieldsAllCompleteUpToDefaultLimit(
        self, rayWaitMock):
    # Load the pool with 1000 tasks, mock them all as complete and then
    # check that the first call to completed_prefetch only yields 999
    # items and the second call yields the final one
    pool = TaskPool()
    for i in range(1000):
        task = createMockWorkerAndObjectRef(i)
        pool.add(*task)

    rayWaitMock.return_value = (list(range(1000)), [])

    # For this test, we're only checking the object refs
    fetched = [pair[1] for pair in pool.completed_prefetch()]
    self.assertListEqual(fetched, list(range(999)))

    # Finally, check that the next iteration returns the final task
    fetched = [pair[1] for pair in pool.completed_prefetch()]
    self.assertListEqual(fetched, [999])
def test_completed_prefetch_yieldsAllCompleteUpToSpecifiedLimit(
        self, rayWaitMock):
    # Load the pool with 1000 tasks, mock them all as complete and then
    # check that completed_prefetch(max_yield=500) only yields the first
    # 500 items and a second, unlimited call yields the rest
    pool = TaskPool()
    for i in range(1000):
        task = createMockWorkerAndObjectRef(i)
        pool.add(*task)

    rayWaitMock.return_value = (list(range(1000)), [])

    # Verify that only the first 500 tasks are returned; this should leave
    # some tasks in the _fetching deque for later
    fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=500)]
    self.assertListEqual(fetched, list(range(500)))

    # Finally, check that the next iteration returns the remaining tasks
    fetched = [pair[1] for pair in pool.completed_prefetch()]
    self.assertListEqual(fetched, list(range(500, 1000)))
def test_completed_prefetch_yieldsRemainingIfIterationStops(
        self, rayWaitMock):
    # Test for issue #7106
    # In versions of Ray up to 0.8.1, if the pre-fetch generator failed to
    # run to completion, then the TaskPool would fail to clear up already
    # fetched tasks, resulting in stale object refs being returned
    pool = TaskPool()
    for i in range(10):
        task = createMockWorkerAndObjectRef(i)
        pool.add(*task)

    rayWaitMock.return_value = (list(range(10)), [])

    # This should fetch just the first item in the list
    try:
        for _ in pool.completed_prefetch():
            # Simulate a worker failure returned by ray.get()
            raise ray.exceptions.RayError
    except ray.exceptions.RayError:
        pass

    # This fetch should return the remaining pre-fetched tasks
    fetched = [pair[1] for pair in pool.completed_prefetch()]
    self.assertListEqual(fetched, list(range(1, 10)))
class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. This class coordinates the data transfers between the learner thread and remote evaluators (IMPALA actors). """ def _init(self, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, num_gpus=0, lr=0.0005, grad_clip=40, replay_buffer_num_slots=0, replay_proportion=0.0, num_parallel_data_loaders=1, max_sample_requests_in_flight_per_worker=2): self.learning_started = False self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size if num_gpus > 1 or num_parallel_data_loaders > 1: logger.info( "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_parallel_data_loaders)) if train_batch_size // max(1, num_gpus) % ( sample_batch_size // num_envs_per_worker) != 0: raise ValueError( "Sample batches must evenly divide across GPUs.") self.learner = TFMultiGPULearner( self.local_evaluator, lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, grad_clip=grad_clip, num_parallel_data_loaders=num_parallel_data_loaders) else: self.learner = LearnerThread(self.local_evaluator) self.learner.start() assert len(self.remote_evaluators) > 0 # Stats self.timers = { k: TimerStat() for k in ["put_weights", "enqueue", "sample_processing", "train", "sample"] } self.num_weight_syncs = 0 self.num_replayed = 0 self.learning_started = False # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] if replay_proportion: assert replay_buffer_num_slots > 0 assert (replay_buffer_num_slots * sample_batch_size > train_batch_size) self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] def step(self): assert self.learner.is_alive() start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.sample_batch_size)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray.get(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch()): self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) with self.timers["enqueue"]: self.learner.inqueue.put(train_batch) self.batch_buffer = [] # If the batch was replayed, skip the update below. 
if ev is None: continue sample_timesteps += sample_batch.count # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: self.replay_batches.append(sample_batch) if len(self.replay_batches) > self.replay_buffer_num_slots: self.replay_batches.pop(0) # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or self.learner.weights_updated: self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put(self.local_evaluator.get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) while not self.learner.outqueue.empty(): count = self.learner.outqueue.get() train_timesteps += count return sample_timesteps, train_timesteps def stats(self): timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_load_time_ms"] = round( 1000 * self.learner.load_timer.mean, 3) timing["learner_load_wait_time_ms"] = round( 1000 * self.learner.load_wait_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, "timing_breakdown": timing, "learner_queue": self.learner.learner_queue_size.stats(), } if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats)
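
# A standalone check of the replay-mixing loop in _augment_with_replay above:
# with replay_proportion = f, the "while random.random() < f: f -= 1" pattern
# emits floor(f) or ceil(f) replay batches per sampled batch, and f batches in
# expectation (e.g. f = 2.5 yields 2 or 3 batches, averaging 2.5). This is an
# illustrative snippet, not part of the optimizer.
import random


def num_replays(replay_proportion):
    count, f = 0, replay_proportion
    while random.random() < f:
        f -= 1
        count += 1
    return count


if __name__ == "__main__":
    draws = [num_replays(2.5) for _ in range(100000)]
    print(sum(draws) / len(draws))  # ~2.5
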
class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. This class coordinates the data transfers between the learner thread and remote evaluators (IMPALA actors). """ @override(PolicyOptimizer) def _init(self, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, num_gpus=0, lr=0.0005, replay_buffer_num_slots=0, replay_proportion=0.0, num_data_loader_buffers=1, max_sample_requests_in_flight_per_worker=2, broadcast_interval=1, num_sgd_iter=1, minibatch_buffer_size=1, learner_queue_size=16, _fake_gpus=False): self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size self.broadcast_interval = broadcast_interval self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_sum = {} if num_gpus > 1 or num_data_loader_buffers > 1: logger.info( "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_data_loader_buffers)) if num_data_loader_buffers < minibatch_buffer_size: raise ValueError( "In multi-gpu mode you must have at least as many " "parallel data loader buffers as minibatch buffers: " "{} vs {}".format(num_data_loader_buffers, minibatch_buffer_size)) self.learner = TFMultiGPULearner( self.local_evaluator, lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, num_data_loader_buffers=num_data_loader_buffers, minibatch_buffer_size=minibatch_buffer_size, num_sgd_iter=num_sgd_iter, learner_queue_size=learner_queue_size, _fake_gpus=_fake_gpus) else: self.learner = LearnerThread(self.local_evaluator, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.learner.start() # Stats self._optimizer_step_timer = TimerStat() self.num_weight_syncs = 0 self.num_replayed = 0 self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_val = {} # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] def add_stat_val(self, key, val): if key not in self._last_stats_sum: self._last_stats_sum[key] = 0 self._last_stats_time[key] = self._stats_start_time self._last_stats_sum[key] += val def get_mean_stats_and_reset(self): now = time.time() mean_stats = { key: round(val / (now - self._last_stats_time[key]), 3) for key, val in self._last_stats_sum.items() } for key in self._last_stats_sum.keys(): self._last_stats_sum[key] = 0 self._last_stats_time[key] = time.time() return mean_stats @override(PolicyOptimizer) def step(self): if len(self.remote_evaluators) == 0: raise ValueError("Config num_workers=0 means training will hang!") assert self.learner.is_alive() with self._optimizer_step_timer: sample_timesteps, train_timesteps = self._step() if sample_timesteps > 0: self.add_stat_val("sample_throughput", sample_timesteps) if train_timesteps > 0: self.add_stat_val("train_throughput", train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps @override(PolicyOptimizer) def stop(self): 
self.learner.stopped = True @override(PolicyOptimizer) def stats(self): def timer_to_ms(timer): return round(1000 * timer.mean, 3) timing = { "optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer), "learner_grad_time_ms": timer_to_ms(self.learner.grad_timer), "learner_load_time_ms": timer_to_ms(self.learner.load_timer), "learner_load_wait_time_ms": timer_to_ms( self.learner.load_wait_timer), "learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer), } stats = dict({ "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, "timing_breakdown": timing, "learner_queue": self.learner.learner_queue_size.stats(), }, **self.get_mean_stats_and_reset()) self._last_stats_val.clear() if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) def _step(self): sample_timesteps, train_timesteps = 0, 0 num_sent = 0 weights = None for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch()): self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) self.learner.inqueue.put(train_batch) self.batch_buffer = [] # If the batch was replayed, skip the update below. if ev is None: continue sample_timesteps += sample_batch.count # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: self.replay_batches.append(sample_batch) if len(self.replay_batches) > self.replay_buffer_num_slots: self.replay_batches.pop(0) # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or (self.learner.weights_updated and num_sent >= self.broadcast_interval): self.learner.weights_updated = False weights = ray.put(self.local_evaluator.get_weights()) num_sent = 0 ev.set_weights.remote(weights) self.num_weight_syncs += 1 num_sent += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) while not self.learner.outqueue.empty(): count = self.learner.outqueue.get() train_timesteps += count return sample_timesteps, train_timesteps def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.sample_batch_size)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray.get(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch
class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. This class coordinates the data transfers between the learner thread and remote evaluators (IMPALA actors). """ def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): self.debug = debug self.learning_started = False self.train_batch_size = train_batch_size self.learner = LearnerThread(self.local_evaluator) self.learner.start() assert len(self.remote_evaluators) > 0 # Stats self.timers = { k: TimerStat() for k in ["put_weights", "enqueue", "sample_processing", "train", "sample"] } self.num_weight_syncs = 0 self.learning_started = False # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(SAMPLE_QUEUE_DEPTH): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] def step(self): assert self.learner.is_alive() start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: for ev, sample_batch in self.sample_tasks.completed_prefetch(): sample_batch = ray.get(sample_batch) sample_timesteps += sample_batch.count self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) with self.timers["enqueue"]: self.learner.inqueue.put((ev, train_batch)) self.batch_buffer = [] # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or self.learner.weights_updated: self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put(self.local_evaluator.get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) while not self.learner.outqueue.empty(): count = self.learner.outqueue.get() train_timesteps += count return sample_timesteps, train_timesteps def stats(self): timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, } debug_stats = { "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), } if self.debug: stats.update(debug_stats) if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats)
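
# The optimizers in this file only rely on a small LearnerThread surface: an
# inqueue of train batches, an outqueue of trained step counts, a
# weights_updated flag, and a stopped flag. The sketch below illustrates that
# contract only; it is not the actual RLlib LearnerThread (which also tracks
# timers, queue stats, and minibatch SGD), and the learn_on_batch() call on
# the local evaluator is an assumption.
import threading
import queue


class MinimalLearnerThread(threading.Thread):
    def __init__(self, local_evaluator, queue_size=16):
        threading.Thread.__init__(self)
        self.daemon = True
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=queue_size)
        self.outqueue = queue.Queue()
        self.weights_updated = False
        self.stopped = False

    def run(self):
        while not self.stopped:
            batch = self.inqueue.get()
            self.local_evaluator.learn_on_batch(batch)
            self.weights_updated = True
            self.outqueue.put(batch.count)
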
def _init(self,
          learning_starts=1000,
          buffer_size=10000,
          prioritized_replay=True,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta=0.4,
          prioritized_replay_eps=1e-6,
          train_batch_size=512,
          sample_batch_size=50,
          num_replay_buffer_shards=1,
          max_weight_sync_delay=400,
          debug=False,
          batch_replay=False):
    self.debug = debug
    self.batch_replay = batch_replay
    self.replay_starts = learning_starts
    self.prioritized_replay_beta = prioritized_replay_beta
    self.prioritized_replay_eps = prioritized_replay_eps
    self.max_weight_sync_delay = max_weight_sync_delay

    self.learner = LearnerThread(self.local_evaluator)
    self.learner.start()

    if self.batch_replay:
        replay_cls = BatchReplayActor
    else:
        replay_cls = ReplayActor
    self.replay_actors = create_colocated(replay_cls, [
        num_replay_buffer_shards,
        learning_starts,
        buffer_size,
        train_batch_size,
        prioritized_replay_alpha,
        prioritized_replay_beta,
        prioritized_replay_eps,
    ], num_replay_buffer_shards)

    # Stats
    self.timers = {
        k: TimerStat()
        for k in [
            "put_weights", "get_samples", "sample_processing",
            "replay_processing", "update_priorities", "train", "sample"
        ]
    }
    self.num_weight_syncs = 0
    self.num_samples_dropped = 0
    self.learning_started = False

    # Number of worker steps since the last weight update
    self.steps_since_update = {}

    # Kick off replay tasks for local gradient updates
    self.replay_tasks = TaskPool()
    for ra in self.replay_actors:
        for _ in range(REPLAY_QUEUE_DEPTH):
            self.replay_tasks.add(ra, ra.replay.remote())

    # Kick off async background sampling
    self.sample_tasks = TaskPool()
    if self.remote_evaluators:
        self._set_evaluators(self.remote_evaluators)
class AggregationWorkerBase(object): """Aggregators should extend from this class.""" def __init__(self, initial_weights_obj_id, remote_evaluators, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size): self.broadcasted_weights = initial_weights_obj_id self.remote_evaluators = remote_evaluators self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) # Kick off async background sampling self.sample_tasks = TaskPool() for ev in self.remote_evaluators: ev.set_weights.remote(self.broadcasted_weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] self.num_sent_since_broadcast = 0 self.num_weight_syncs = 0 self.num_replayed = 0 @override(Aggregator) def iter_train_batches(self, max_yield=999): """Iterate over train batches. Arguments: max_yield (int): Max number of batches to iterate over in this cycle. Setting this avoids iter_train_batches returning too much data at once. """ for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch(blocking_wait=True, max_yield=max_yield)): sample_batch.decompress_if_needed() self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) yield train_batch self.batch_buffer = [] # If the batch was replayed, skip the update below. if ev is None: continue # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: self.replay_batches.append(sample_batch) if len(self.replay_batches) > self.replay_buffer_num_slots: self.replay_batches.pop(0) ev.set_weights.remote(self.broadcasted_weights) self.num_weight_syncs += 1 self.num_sent_since_broadcast += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) @override(Aggregator) def stats(self): return { "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, } @override(Aggregator) def reset(self, remote_evaluators): self.sample_tasks.reset_evaluators(remote_evaluators) def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.sample_batch_size)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray.get(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch
class ApexOptimizer(Optimizer): def _init(self, learning_starts=1000, buffer_size=10000, prioritized_replay=True, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, train_batch_size=512, sample_batch_size=50, num_replay_buffer_shards=1, max_weight_sync_delay=400): self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size self.max_weight_sync_delay = max_weight_sync_delay self.learner = GenericLearner(self.local_evaluator) self.learner.start() self.replay_actors = create_colocated(ReplayActor, [ num_replay_buffer_shards, learning_starts, buffer_size, train_batch_size, prioritized_replay_alpha, prioritized_replay_beta, prioritized_replay_eps ], num_replay_buffer_shards) assert len(self.remote_evaluators) > 0 # Stats self.timers = { k: TimerStat() for k in [ "put_weights", "get_samples", "enqueue", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.meters = { k: WindowStat(k, 10) for k in [ "samples_per_loop", "replays_per_loop", "reprios_per_loop", "reweights_per_loop" ] } self.num_weight_syncs = 0 self.learning_started = False # Number of worker steps since the last weight update self.steps_since_update = {} # Otherwise kick of replay tasks for local gradient updates self.replay_tasks = TaskPool() for ra in self.replay_actors: for _ in range(REPLAY_QUEUE_DEPTH): self.replay_tasks.add(ra, ra.replay.remote()) # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): self.sample_tasks.add(ev, ev.sample.remote()) def step(self): start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: i = 0 num_weight_syncs = 0 for ev, sample_batch in self.sample_tasks.completed(): i += 1 sample_timesteps += self.sample_batch_size # Send the data to the replay buffer random.choice( self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed self.steps_since_update[ev] += self.sample_batch_size if self.steps_since_update[ev] >= self.max_weight_sync_delay: if weights is None: with self.timers["put_weights"]: weights = ray.put( self.local_evaluator.get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 num_weight_syncs += 1 self.steps_since_update[ev] = 0 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) self.meters["samples_per_loop"].push(i) self.meters["reweights_per_loop"].push(num_weight_syncs) with self.timers["replay_processing"]: i = 0 for ra, replay in self.replay_tasks.completed(): i += 1 self.replay_tasks.add(ra, ra.replay.remote()) with self.timers["get_samples"]: samples = ray.get(replay) with self.timers["enqueue"]: self.learner.inqueue.put((ra, samples)) self.meters["replays_per_loop"].push(i) with 
self.timers["update_priorities"]: i = 0 while not self.learner.outqueue.empty(): i += 1 ra, replay, td_error = self.learner.outqueue.get() ra.update_priorities.remote(replay, td_error) train_timesteps += self.train_batch_size self.meters["reprios_per_loop"].push(i) return sample_timesteps, train_timesteps def stats(self): replay_stats = ray.get(self.replay_actors[0].stats.remote()) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "replay_shard_0": replay_stats, "timing_breakdown": timing, "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, "pending_sample_tasks": self.sample_tasks.count, "pending_replay_tasks": self.replay_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), "samples": self.meters["samples_per_loop"].stats(), "replays": self.meters["replays_per_loop"].stats(), "reprios": self.meters["reprios_per_loop"].stats(), "reweights": self.meters["reweights_per_loop"].stats(), } return dict(Optimizer.stats(self), **stats)
class TreeAggregator(Aggregator): """A hierarchical experiences aggregator. The given set of remote workers is divided into subsets and assigned to one of several aggregation workers. These aggregation workers collate experiences into batches of size `train_batch_size` and we collect them in this class when `iter_train_batches` is called. """ def __init__(self, workers, num_aggregation_workers, max_sample_requests_in_flight_per_worker=2, replay_proportion=0.0, replay_buffer_num_slots=0, train_batch_size=500, sample_batch_size=50, broadcast_interval=5): """Initialize a tree aggregator. Arguments: workers (WorkerSet): set of all workers num_aggregation_workers (int): number of intermediate actors to use for data aggregation max_sample_request_in_flight_per_worker (int): max queue size per worker replay_proportion (float): ratio of replay to sampled outputs replay_buffer_num_slots (int): max number of sample batches to store in the replay buffer train_batch_size (int): size of batches to learn on sample_batch_size (int): size of batches to sample from workers broadcast_interval (int): max number of workers to send the same set of weights to """ self.workers = workers self.num_aggregation_workers = num_aggregation_workers self.max_sample_requests_in_flight_per_worker = \ max_sample_requests_in_flight_per_worker self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size self.broadcast_interval = broadcast_interval self.broadcasted_weights = ray.put( workers.local_worker().get_weights()) self.num_batches_processed = 0 self.num_broadcasts = 0 self.num_sent_since_broadcast = 0 self.initialized = False def init(self, aggregators): """Deferred init so that we can pass in previously created workers.""" assert len(aggregators) == self.num_aggregation_workers, aggregators if len(self.workers.remote_workers()) < self.num_aggregation_workers: raise ValueError( "The number of aggregation workers should not exceed the " "number of total evaluation workers ({} vs {})".format( self.num_aggregation_workers, len(self.workers.remote_workers()))) assigned_workers = collections.defaultdict(list) for i, ev in enumerate(self.workers.remote_workers()): assigned_workers[i % self.num_aggregation_workers].append(ev) self.aggregators = aggregators for i, agg in enumerate(self.aggregators): agg.init.remote(self.broadcasted_weights, assigned_workers[i], self.max_sample_requests_in_flight_per_worker, self.replay_proportion, self.replay_buffer_num_slots, self.train_batch_size, self.sample_batch_size) self.agg_tasks = TaskPool() for agg in self.aggregators: agg.set_weights.remote(self.broadcasted_weights) self.agg_tasks.add(agg, agg.get_train_batches.remote()) self.initialized = True @override(Aggregator) def iter_train_batches(self): assert self.initialized, "Must call init() before using this class." 
for agg, batches in self.agg_tasks.completed_prefetch(): for b in ray_get_and_free(batches): self.num_sent_since_broadcast += 1 yield b agg.set_weights.remote(self.broadcasted_weights) self.agg_tasks.add(agg, agg.get_train_batches.remote()) self.num_batches_processed += 1 @override(Aggregator) def broadcast_new_weights(self): self.broadcasted_weights = ray.put( self.workers.local_worker().get_weights()) self.num_sent_since_broadcast = 0 self.num_broadcasts += 1 @override(Aggregator) def should_broadcast(self): return self.num_sent_since_broadcast >= self.broadcast_interval @override(Aggregator) def stats(self): return { "num_broadcasts": self.num_broadcasts, "num_batches_processed": self.num_batches_processed, } @override(Aggregator) def reset(self, remote_workers): raise NotImplementedError("changing number of remote workers") @staticmethod def precreate_aggregators(n): return create_colocated(AggregationWorker, [], n)
class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). This class coordinates the data transfers between the learner thread, remote workers (Ape-X actors), and replay buffer actors. This has two modes of operation: - normal replay: replays independent samples. - batch replay: simplified mode where entire sample batches are replayed. This supports RNNs, but not prioritization. This optimizer requires that rollout workers return an additional "td_error" array in the info return of compute_gradients(). This error term will be used for sample prioritization.""" def __init__(self, workers, learning_starts=1000, buffer_size=10000, prioritized_replay=True, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, train_batch_size=512, rollout_fragment_length=50, num_replay_buffer_shards=1, max_weight_sync_delay=400, debug=False, batch_replay=False): """Initialize an async replay optimizer. Arguments: workers (WorkerSet): all workers learning_starts (int): wait until this many steps have been sampled before starting optimization. buffer_size (int): max size of the replay buffer prioritized_replay (bool): whether to enable prioritized replay prioritized_replay_alpha (float): replay alpha hyperparameter prioritized_replay_beta (float): replay beta hyperparameter prioritized_replay_eps (float): replay eps hyperparameter train_batch_size (int): size of batches to learn on rollout_fragment_length (int): size of batches to sample from workers. num_replay_buffer_shards (int): number of actors to use to store replay samples max_weight_sync_delay (int): update the weights of a rollout worker after collecting this number of timesteps from it debug (bool): return extra debug stats batch_replay (bool): replay entire sequential batches of experiences instead of sampling steps individually """ PolicyOptimizer.__init__(self, workers) self.debug = debug self.batch_replay = batch_replay self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps self.max_weight_sync_delay = max_weight_sync_delay self.learner = LearnerThread(self.workers.local_worker()) self.learner.start() if self.batch_replay: replay_cls = BatchReplayActor else: replay_cls = ReplayActor self.replay_actors = create_colocated(replay_cls, [ num_replay_buffer_shards, learning_starts, buffer_size, train_batch_size, prioritized_replay_alpha, prioritized_replay_beta, prioritized_replay_eps, ], num_replay_buffer_shards) # Stats self.timers = { k: TimerStat() for k in [ "put_weights", "get_samples", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.num_weight_syncs = 0 self.num_samples_dropped = 0 self.learning_started = False # Number of worker steps since the last weight update self.steps_since_update = {} # Otherwise kick of replay tasks for local gradient updates self.replay_tasks = TaskPool() for ra in self.replay_actors: for _ in range(REPLAY_QUEUE_DEPTH): self.replay_tasks.add(ra, ra.replay.remote()) # Kick off async background sampling self.sample_tasks = TaskPool() if self.workers.remote_workers(): self._set_workers(self.workers.remote_workers()) @override(PolicyOptimizer) def step(self): assert self.learner.is_alive() assert len(self.workers.remote_workers()) > 0 start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) 
self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps @override(PolicyOptimizer) def stop(self): for r in self.replay_actors: r.__ray_terminate__.remote() self.learner.stopped = True @override(PolicyOptimizer) def reset(self, remote_workers): self.workers.reset(remote_workers) self.sample_tasks.reset_workers(remote_workers) @override(PolicyOptimizer) def stats(self): replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote( self.debug)) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, "num_samples_dropped": self.num_samples_dropped, "learner_queue": self.learner.learner_queue_size.stats(), "replay_shard_0": replay_stats, } debug_stats = { "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, "pending_replay_tasks": self.replay_tasks.count, } if self.debug: stats.update(debug_stats) if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) # For https://github.com/ray-project/ray/issues/2541 only def _set_workers(self, remote_workers): self.workers.reset(remote_workers) weights = self.workers.local_worker().get_weights() for ev in self.workers.remote_workers(): ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): self.sample_tasks.add(ev, ev.sample_with_count.remote()) def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: completed = list(self.sample_tasks.completed()) # First try a batched ray.get(). ray_error = None try: counts = { i: v for i, v in enumerate( ray_get_and_free([c[1][1] for c in completed])) } # If there are failed workers, try to recover the still good ones # (via non-batched ray.get()) and store the first error (to raise # later). except RayError: counts = {} for i, c in enumerate(completed): try: counts[i] = ray_get_and_free(c[1][1]) except RayError as e: logger.exception( "Error in completed task: {}".format(e)) ray_error = ray_error if ray_error is not None else e for i, (ev, (sample_batch, count)) in enumerate(completed): # Skip failed tasks. if i not in counts: continue sample_timesteps += counts[i] # Send the data to the replay buffer random.choice( self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed. self.steps_since_update[ev] += counts[i] if self.steps_since_update[ev] >= self.max_weight_sync_delay: # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors. if weights is None or self.learner.weights_updated: self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put( self.workers.local_worker().get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 self.steps_since_update[ev] = 0 # Kick off another sample request. 
self.sample_tasks.add(ev, ev.sample_with_count.remote()) # Now that all still good tasks have been kicked off again, # we can throw the error. if ray_error: raise ray_error with self.timers["replay_processing"]: for ra, replay in self.replay_tasks.completed(): self.replay_tasks.add(ra, ra.replay.remote()) if self.learner.inqueue.full(): self.num_samples_dropped += 1 else: with self.timers["get_samples"]: samples = ray_get_and_free(replay) # Defensive copy against plasma crashes, see #2610 #3452 self.learner.inqueue.put((ra, samples and samples.copy())) with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): ra, prio_dict, count = self.learner.outqueue.get() ra.update_priorities.remote(prio_dict) train_timesteps += count return sample_timesteps, train_timesteps
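
# The optimizer snippets above reference SAMPLE_QUEUE_DEPTH and
# REPLAY_QUEUE_DEPTH without defining them. They are small module-level
# constants bounding how many sample / replay requests stay in flight per
# actor, trading memory for pipeline utilization. The values below are
# plausible placeholders and may not match the upstream file.
SAMPLE_QUEUE_DEPTH = 2
REPLAY_QUEUE_DEPTH = 4
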
class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. This class coordinates the data transfers between the learner thread and remote evaluators (IMPALA actors). """ @override(PolicyOptimizer) def _init(self, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, num_gpus=0, lr=0.0005, replay_buffer_num_slots=0, replay_proportion=0.0, num_data_loader_buffers=1, max_sample_requests_in_flight_per_worker=2, broadcast_interval=1, num_sgd_iter=1, minibatch_buffer_size=1, learner_queue_size=16, _fake_gpus=False): self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size self.broadcast_interval = broadcast_interval self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_sum = {} if num_gpus > 1 or num_data_loader_buffers > 1: logger.info( "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_data_loader_buffers)) if num_data_loader_buffers < minibatch_buffer_size: raise ValueError( "In multi-gpu mode you must have at least as many " "parallel data loader buffers as minibatch buffers: " "{} vs {}".format(num_data_loader_buffers, minibatch_buffer_size)) self.learner = TFMultiGPULearner( self.local_evaluator, lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, num_data_loader_buffers=num_data_loader_buffers, minibatch_buffer_size=minibatch_buffer_size, num_sgd_iter=num_sgd_iter, learner_queue_size=learner_queue_size, _fake_gpus=_fake_gpus) else: self.learner = LearnerThread(self.local_evaluator, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.learner.start() if len(self.remote_evaluators) == 0: logger.warning("Config num_workers=0 means training will hang!") # Stats self._optimizer_step_timer = TimerStat() self.num_weight_syncs = 0 self.num_replayed = 0 self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_val = {} # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] def add_stat_val(self, key, val): if key not in self._last_stats_sum: self._last_stats_sum[key] = 0 self._last_stats_time[key] = self._stats_start_time self._last_stats_sum[key] += val def get_mean_stats_and_reset(self): now = time.time() mean_stats = { key: round(val / (now - self._last_stats_time[key]), 3) for key, val in self._last_stats_sum.items() } for key in self._last_stats_sum.keys(): self._last_stats_sum[key] = 0 self._last_stats_time[key] = time.time() return mean_stats @override(PolicyOptimizer) def step(self): assert self.learner.is_alive() with self._optimizer_step_timer: sample_timesteps, train_timesteps = self._step() if sample_timesteps > 0: self.add_stat_val("sample_throughput", sample_timesteps) if train_timesteps > 0: self.add_stat_val("train_throughput", train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps @override(PolicyOptimizer) def stop(self): 
self.learner.stopped = True @override(PolicyOptimizer) def stats(self): def timer_to_ms(timer): return round(1000 * timer.mean, 3) timing = { "optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer), "learner_grad_time_ms": timer_to_ms(self.learner.grad_timer), "learner_load_time_ms": timer_to_ms(self.learner.load_timer), "learner_load_wait_time_ms": timer_to_ms(self.learner.load_wait_timer), "learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer), } stats = dict( { "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, "timing_breakdown": timing, "learner_queue": self.learner.learner_queue_size.stats(), }, **self.get_mean_stats_and_reset()) self._last_stats_val.clear() if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) def _step(self): sample_timesteps, train_timesteps = 0, 0 num_sent = 0 weights = None for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch()): self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) self.learner.inqueue.put(train_batch) self.batch_buffer = [] # If the batch was replayed, skip the update below. if ev is None: continue sample_timesteps += sample_batch.count # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: self.replay_batches.append(sample_batch) if len(self.replay_batches) > self.replay_buffer_num_slots: self.replay_batches.pop(0) # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or (self.learner.weights_updated and num_sent >= self.broadcast_interval): self.learner.weights_updated = False weights = ray.put(self.local_evaluator.get_weights()) num_sent = 0 ev.set_weights.remote(weights) self.num_weight_syncs += 1 num_sent += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) while not self.learner.outqueue.empty(): count = self.learner.outqueue.get() train_timesteps += count return sample_timesteps, train_timesteps def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.sample_batch_size)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray.get(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch
class AggregationWorkerBase: """Aggregators should extend from this class.""" def __init__(self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, rollout_fragment_length): """Initialize an aggregator. Arguments: initial_weights_obj_id (ObjectID): initial worker weights remote_workers (list): set of remote workers assigned to this agg max_sample_request_in_flight_per_worker (int): max queue size per worker replay_proportion (float): ratio of replay to sampled outputs replay_buffer_num_slots (int): max number of sample batches to store in the replay buffer train_batch_size (int): size of batches to learn on rollout_fragment_length (int): size of batches to sample from workers. """ self.broadcasted_weights = initial_weights_obj_id self.remote_workers = remote_workers self.rollout_fragment_length = rollout_fragment_length self.train_batch_size = train_batch_size if replay_proportion: if (replay_buffer_num_slots * rollout_fragment_length <= train_batch_size): raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, rollout_fragment_length, train_batch_size) # Kick off async background sampling self.sample_tasks = TaskPool() for ev in self.remote_workers: ev.set_weights.remote(self.broadcasted_weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = [] self.replay_index = 0 self.num_sent_since_broadcast = 0 self.num_weight_syncs = 0 self.num_replayed = 0 @override(Aggregator) def iter_train_batches(self, max_yield=999): """Iterate over train batches. Arguments: max_yield (int): Max number of batches to iterate over in this cycle. Setting this avoids iter_train_batches returning too much data at once. """ for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch(blocking_wait=True, max_yield=max_yield)): sample_batch.decompress_if_needed() self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: if len(self.batch_buffer) == 1: # make a defensive copy to avoid sharing plasma memory # across multiple threads train_batch = self.batch_buffer[0].copy() else: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) yield train_batch self.batch_buffer = [] # If the batch was replayed, skip the update below. 
if ev is None: continue # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: if len(self.replay_batches) < self.replay_buffer_num_slots: self.replay_batches.append(sample_batch) else: self.replay_batches[self.replay_index] = sample_batch self.replay_index += 1 self.replay_index %= self.replay_buffer_num_slots ev.set_weights.remote(self.broadcasted_weights) self.num_weight_syncs += 1 self.num_sent_since_broadcast += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) @override(Aggregator) def stats(self): return { "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, } @override(Aggregator) def reset(self, remote_workers): self.sample_tasks.reset_workers(remote_workers) def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.rollout_fragment_length)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray.get(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch
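# Illustrative sketch: the _augment_with_replay loop above emits, for every
# freshly sampled batch, floor(p) replayed batches plus one more with
# probability frac(p), so the expected ratio of replayed to sampled batches
# equals replay_proportion. This standalone check (num_replays_drawn is a
# hypothetical helper, not part of the aggregator) verifies that expectation.
import random


def num_replays_drawn(replay_proportion):
    """Return how many replay batches the loop would yield for one sample."""
    n, f = 0, replay_proportion
    while random.random() < f:
        f -= 1
        n += 1
    return n


if __name__ == "__main__":
    p = 2.5
    trials = 100_000
    mean = sum(num_replays_drawn(p) for _ in range(trials)) / trials
    print(f"expected {p}, observed {mean:.3f}")  # observed mean is close to p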
class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). This class coordinates the data transfers between the learner thread, remote evaluators (Ape-X actors), and replay buffer actors. This optimizer requires that policy evaluators return an additional "td_error" array in the info return of compute_gradients(). This error term will be used for sample prioritization.""" def _init(self, learning_starts=1000, buffer_size=10000, prioritized_replay=True, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, train_batch_size=512, sample_batch_size=50, num_replay_buffer_shards=1, max_weight_sync_delay=400, debug=False): self.debug = debug self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps self.max_weight_sync_delay = max_weight_sync_delay self.learner = LearnerThread(self.local_evaluator) self.learner.start() self.replay_actors = create_colocated(ReplayActor, [ num_replay_buffer_shards, learning_starts, buffer_size, train_batch_size, prioritized_replay_alpha, prioritized_replay_beta, prioritized_replay_eps, ], num_replay_buffer_shards) # Stats self.timers = { k: TimerStat() for k in [ "put_weights", "get_samples", "enqueue", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.num_weight_syncs = 0 self.learning_started = False # Number of worker steps since the last weight update self.steps_since_update = {} # Otherwise kick of replay tasks for local gradient updates self.replay_tasks = TaskPool() for ra in self.replay_actors: for _ in range(REPLAY_QUEUE_DEPTH): self.replay_tasks.add(ra, ra.replay.remote()) # Kick off async background sampling self.sample_tasks = TaskPool() if self.remote_evaluators: self.set_evaluators(self.remote_evaluators) # For https://github.com/ray-project/ray/issues/2541 only def set_evaluators(self, remote_evaluators): self.remote_evaluators = remote_evaluators weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): self.sample_tasks.add(ev, ev.sample_with_count.remote()) def step(self): assert len(self.remote_evaluators) > 0 start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: completed = list(self.sample_tasks.completed()) counts = ray.get([c[1][1] for c in completed]) for i, (ev, (sample_batch, count)) in enumerate(completed): sample_timesteps += counts[i] # Send the data to the replay buffer random.choice( self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed self.steps_since_update[ev] += counts[i] if self.steps_since_update[ev] >= self.max_weight_sync_delay: # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or self.learner.weights_updated: self.learner.weights_updated = False with self.timers["put_weights"]: 
weights = ray.put( self.local_evaluator.get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 self.steps_since_update[ev] = 0 # Kick off another sample request self.sample_tasks.add(ev, ev.sample_with_count.remote()) with self.timers["replay_processing"]: for ra, replay in self.replay_tasks.completed(): self.replay_tasks.add(ra, ra.replay.remote()) with self.timers["get_samples"]: samples = ray.get(replay) with self.timers["enqueue"]: self.learner.inqueue.put((ra, samples)) with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): ra, replay, td_error, count = self.learner.outqueue.get() ra.update_priorities.remote(replay["batch_indexes"], td_error) train_timesteps += count return sample_timesteps, train_timesteps def stats(self): replay_stats = ray.get(self.replay_actors[0].stats.remote()) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, } debug_stats = { "replay_shard_0": replay_stats, "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, "pending_replay_tasks": self.replay_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), } if self.debug: stats.update(debug_stats) return dict(PolicyOptimizer.stats(self), **stats)
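# Illustrative sketch: weight broadcasting in the _step loop above is
# throttled per worker -- a worker only receives fresh weights after it has
# returned at least max_weight_sync_delay timesteps since its previous sync,
# and the serialized weights are produced at most once per optimizer step and
# reused for every worker synced in that step. The toy bookkeeping below
# (WeightSyncThrottle is a hypothetical class) mirrors that logic without Ray.
class WeightSyncThrottle:
    def __init__(self, max_weight_sync_delay=400):
        self.max_weight_sync_delay = max_weight_sync_delay
        self.steps_since_update = {}
        self.num_weight_syncs = 0

    def record_samples(self, worker_id, count):
        """Return True if this worker should be sent fresh weights now."""
        self.steps_since_update[worker_id] = (
            self.steps_since_update.get(worker_id, 0) + count)
        if self.steps_since_update[worker_id] >= self.max_weight_sync_delay:
            self.steps_since_update[worker_id] = 0
            self.num_weight_syncs += 1
            return True
        return False


if __name__ == "__main__":
    throttle = WeightSyncThrottle(max_weight_sync_delay=400)
    # A worker returning 50-step batches triggers a sync on its 8th batch.
    syncs = [throttle.record_samples("worker_a", 50) for _ in range(8)]
    print(syncs)  # [False] * 7 + [True]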
def _init(self, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, num_gpus=0, lr=0.0005, grad_clip=40, replay_buffer_num_slots=0, replay_proportion=0.0, num_parallel_data_loaders=1, max_sample_requests_in_flight_per_worker=2): self.learning_started = False self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size if num_gpus > 1 or num_parallel_data_loaders > 1: logger.info( "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_parallel_data_loaders)) if train_batch_size // max(1, num_gpus) % ( sample_batch_size // num_envs_per_worker) != 0: raise ValueError( "Sample batches must evenly divide across GPUs.") self.learner = TFMultiGPULearner( self.local_evaluator, lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, grad_clip=grad_clip, num_parallel_data_loaders=num_parallel_data_loaders) else: self.learner = LearnerThread(self.local_evaluator) self.learner.start() assert len(self.remote_evaluators) > 0 # Stats self.timers = { k: TimerStat() for k in ["put_weights", "enqueue", "sample_processing", "train", "sample"] } self.num_weight_syncs = 0 self.num_replayed = 0 self.learning_started = False # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] if replay_proportion: assert replay_buffer_num_slots > 0 assert (replay_buffer_num_slots * sample_batch_size > train_batch_size) self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = []
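# Illustrative sketch: the multi-GPU branch above requires that the per-GPU
# share of the train batch be an exact multiple of the per-env fragment each
# worker returns (sample_batch_size // num_envs_per_worker); otherwise sample
# batches cannot be split evenly across the parallel data loaders. The helper
# below (check_gpu_batch_split is a hypothetical name) restates that check.
def check_gpu_batch_split(train_batch_size, sample_batch_size,
                          num_envs_per_worker=1, num_gpus=0):
    per_gpu = train_batch_size // max(1, num_gpus)
    per_env_fragment = sample_batch_size // num_envs_per_worker
    if per_gpu % per_env_fragment != 0:
        raise ValueError(
            "Sample batches must evenly divide across GPUs: "
            "{} per-GPU vs {} per-env fragment".format(
                per_gpu, per_env_fragment))


if __name__ == "__main__":
    check_gpu_batch_split(500, 50, num_envs_per_worker=1, num_gpus=2)  # 250 % 50 == 0, ok
    try:
        check_gpu_batch_split(500, 50, num_envs_per_worker=1, num_gpus=3)  # 166 % 50 != 0
    except ValueError as e:
        print(e)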
def _init(self, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, num_gpus=0, lr=0.0005, replay_buffer_num_slots=0, replay_proportion=0.0, num_data_loader_buffers=1, max_sample_requests_in_flight_per_worker=2, broadcast_interval=1, num_sgd_iter=1, minibatch_buffer_size=1, learner_queue_size=16, _fake_gpus=False): self.train_batch_size = train_batch_size self.sample_batch_size = sample_batch_size self.broadcast_interval = broadcast_interval self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_sum = {} if num_gpus > 1 or num_data_loader_buffers > 1: logger.info( "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_data_loader_buffers)) if num_data_loader_buffers < minibatch_buffer_size: raise ValueError( "In multi-gpu mode you must have at least as many " "parallel data loader buffers as minibatch buffers: " "{} vs {}".format(num_data_loader_buffers, minibatch_buffer_size)) self.learner = TFMultiGPULearner( self.local_evaluator, lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, num_data_loader_buffers=num_data_loader_buffers, minibatch_buffer_size=minibatch_buffer_size, num_sgd_iter=num_sgd_iter, learner_queue_size=learner_queue_size, _fake_gpus=_fake_gpus) else: self.learner = LearnerThread(self.local_evaluator, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.learner.start() if len(self.remote_evaluators) == 0: logger.warning("Config num_workers=0 means training will hang!") # Stats self._optimizer_step_timer = TimerStat() self.num_weight_syncs = 0 self.num_replayed = 0 self._stats_start_time = time.time() self._last_stats_time = {} self._last_stats_val = {} # Kick off async background sampling self.sample_tasks = TaskPool() weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.replay_batches = []
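# Illustrative sketch: downstream of this _init, sampled batches are appended
# to self.batch_buffer until their combined step count reaches
# train_batch_size, at which point they are concatenated into one train batch
# and the buffer is cleared. The minimal stand-in below uses hypothetical
# Batch and BatchAccumulator classes, not RLlib's SampleBatch.
class Batch:
    def __init__(self, count):
        self.count = count

    @staticmethod
    def concat_samples(batches):
        return Batch(sum(b.count for b in batches))


class BatchAccumulator:
    def __init__(self, train_batch_size):
        self.train_batch_size = train_batch_size
        self.buffer = []

    def add(self, batch):
        """Return a concatenated train batch once enough samples accumulate."""
        self.buffer.append(batch)
        if sum(b.count for b in self.buffer) >= self.train_batch_size:
            train_batch = Batch.concat_samples(self.buffer)
            self.buffer = []
            return train_batch
        return None


if __name__ == "__main__":
    acc = BatchAccumulator(train_batch_size=500)
    results = [acc.add(Batch(50)) for _ in range(10)]
    print([r.count for r in results if r is not None])  # [500]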
class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). This class coordinates the data transfers between the learner thread, remote evaluators (Ape-X actors), and replay buffer actors. This has two modes of operation: - normal replay: replays independent samples. - batch replay: simplified mode where entire sample batches are replayed. This supports RNNs, but not prioritization. This optimizer requires that policy evaluators return an additional "td_error" array in the info return of compute_gradients(). This error term will be used for sample prioritization.""" @override(PolicyOptimizer) def _init(self, learning_starts=1000, buffer_size=10000, prioritized_replay=True, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, train_batch_size=512, sample_batch_size=50, num_replay_buffer_shards=1, max_weight_sync_delay=400, debug=False, batch_replay=False): self.debug = debug self.batch_replay = batch_replay self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps self.max_weight_sync_delay = max_weight_sync_delay self.learner = LearnerThread(self.local_evaluator) self.learner.start() if self.batch_replay: replay_cls = BatchReplayActor else: replay_cls = ReplayActor self.replay_actors = create_colocated(replay_cls, [ num_replay_buffer_shards, learning_starts, buffer_size, train_batch_size, prioritized_replay_alpha, prioritized_replay_beta, prioritized_replay_eps, ], num_replay_buffer_shards) # Stats self.timers = { k: TimerStat() for k in [ "put_weights", "get_samples", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.num_weight_syncs = 0 self.num_samples_dropped = 0 self.learning_started = False # Number of worker steps since the last weight update self.steps_since_update = {} # Otherwise kick of replay tasks for local gradient updates self.replay_tasks = TaskPool() for ra in self.replay_actors: for _ in range(REPLAY_QUEUE_DEPTH): self.replay_tasks.add(ra, ra.replay.remote()) # Kick off async background sampling self.sample_tasks = TaskPool() if self.remote_evaluators: self._set_evaluators(self.remote_evaluators) @override(PolicyOptimizer) def step(self): assert self.learner.is_alive() assert len(self.remote_evaluators) > 0 start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start self.timers["sample"].push(time_delta) self.timers["sample"].push_units_processed(sample_timesteps) if train_timesteps > 0: self.learning_started = True if self.learning_started: self.timers["train"].push(time_delta) self.timers["train"].push_units_processed(train_timesteps) self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps @override(PolicyOptimizer) def stop(self): for r in self.replay_actors: r.__ray_terminate__.remote() self.learner.stopped = True @override(PolicyOptimizer) def stats(self): replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug)) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { "sample_throughput": round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, 
"num_samples_dropped": self.num_samples_dropped, "learner_queue": self.learner.learner_queue_size.stats(), "replay_shard_0": replay_stats, } debug_stats = { "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, "pending_replay_tasks": self.replay_tasks.count, } if self.debug: stats.update(debug_stats) if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) # For https://github.com/ray-project/ray/issues/2541 only def _set_evaluators(self, remote_evaluators): self.remote_evaluators = remote_evaluators weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): self.sample_tasks.add(ev, ev.sample_with_count.remote()) def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: completed = list(self.sample_tasks.completed()) counts = ray.get([c[1][1] for c in completed]) for i, (ev, (sample_batch, count)) in enumerate(completed): sample_timesteps += counts[i] # Send the data to the replay buffer random.choice( self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed self.steps_since_update[ev] += counts[i] if self.steps_since_update[ev] >= self.max_weight_sync_delay: # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or self.learner.weights_updated: self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put( self.local_evaluator.get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 self.steps_since_update[ev] = 0 # Kick off another sample request self.sample_tasks.add(ev, ev.sample_with_count.remote()) with self.timers["replay_processing"]: for ra, replay in self.replay_tasks.completed(): self.replay_tasks.add(ra, ra.replay.remote()) if self.learner.inqueue.full(): self.num_samples_dropped += 1 else: with self.timers["get_samples"]: samples = ray.get(replay) # Defensive copy against plasma crashes, see #2610 #3452 self.learner.inqueue.put((ra, samples and samples.copy())) with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): ra, prio_dict, count = self.learner.outqueue.get() ra.update_priorities.remote(prio_dict) train_timesteps += count return sample_timesteps, train_timesteps
def __init__(self, workers, learning_starts=1000, buffer_size=10000, prioritized_replay=True, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, train_batch_size=512, rollout_fragment_length=50, num_replay_buffer_shards=1, max_weight_sync_delay=400, debug=False, batch_replay=False): """Initialize an async replay optimizer. Arguments: workers (WorkerSet): all workers learning_starts (int): wait until this many steps have been sampled before starting optimization. buffer_size (int): max size of the replay buffer prioritized_replay (bool): whether to enable prioritized replay prioritized_replay_alpha (float): replay alpha hyperparameter prioritized_replay_beta (float): replay beta hyperparameter prioritized_replay_eps (float): replay eps hyperparameter train_batch_size (int): size of batches to learn on rollout_fragment_length (int): size of batches to sample from workers. num_replay_buffer_shards (int): number of actors to use to store replay samples max_weight_sync_delay (int): update the weights of a rollout worker after collecting this number of timesteps from it debug (bool): return extra debug stats batch_replay (bool): replay entire sequential batches of experiences instead of sampling steps individually """ PolicyOptimizer.__init__(self, workers) self.debug = debug self.batch_replay = batch_replay self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps self.max_weight_sync_delay = max_weight_sync_delay self.learner = LearnerThread(self.workers.local_worker()) self.learner.start() if self.batch_replay: replay_cls = BatchReplayActor else: replay_cls = ReplayActor self.replay_actors = create_colocated(replay_cls, [ num_replay_buffer_shards, learning_starts, buffer_size, train_batch_size, prioritized_replay_alpha, prioritized_replay_beta, prioritized_replay_eps, ], num_replay_buffer_shards) # Stats self.timers = { k: TimerStat() for k in [ "put_weights", "get_samples", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.num_weight_syncs = 0 self.num_samples_dropped = 0 self.learning_started = False # Number of worker steps since the last weight update self.steps_since_update = {} # Otherwise kick of replay tasks for local gradient updates self.replay_tasks = TaskPool() for ra in self.replay_actors: for _ in range(REPLAY_QUEUE_DEPTH): self.replay_tasks.add(ra, ra.replay.remote()) # Kick off async background sampling self.sample_tasks = TaskPool() if self.workers.remote_workers(): self._set_workers(self.workers.remote_workers())
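# Illustrative sketch: with several replay buffer shards, each incoming sample
# batch is routed to a shard chosen uniformly at random, so storage and replay
# load spread evenly across the shards without coordination. The toy version
# below (ShardedReplay is a hypothetical class) shows only the routing; the
# real shards are Ray actors created via create_colocated.
import random


class ShardedReplay:
    def __init__(self, num_shards):
        self.shards = [[] for _ in range(num_shards)]

    def add_batch(self, batch):
        # Uniform random routing, matching random.choice(self.replay_actors).
        random.choice(self.shards).append(batch)

    def shard_sizes(self):
        return [len(s) for s in self.shards]


if __name__ == "__main__":
    replay = ShardedReplay(num_shards=4)
    for i in range(10_000):
        replay.add_batch(i)
    print(replay.shard_sizes())  # roughly 2500 batches per shard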
class DRAggregatorBase: """Aggregators should extend from this class.""" def __init__(self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size, sync_sampling=False): """Initialize an aggregator. Arguments: initial_weights_obj_id (ObjectID): initial worker weights remote_workers (list): set of remote workers assigned to this agg max_sample_request_in_flight_per_worker (int): max queue size per worker replay_proportion (float): ratio of replay to sampled outputs replay_buffer_num_slots (int): max number of sample batches to store in the replay buffer train_batch_size (int): size of batches to learn on sample_batch_size (int): size of batches to sample from workers """ self.broadcasted_weights = initial_weights_obj_id self.remote_workers = remote_workers self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size if replay_proportion: if replay_buffer_num_slots * sample_batch_size <= train_batch_size: raise ValueError( "Replay buffer size is too small to produce train, " "please increase replay_buffer_num_slots.", replay_buffer_num_slots, sample_batch_size, train_batch_size) self.batch_buffer = [] self.replay_proportion = replay_proportion self.replay_buffer_num_slots = replay_buffer_num_slots self.max_sample_requests_in_flight_per_worker = \ max_sample_requests_in_flight_per_worker self.started = False self.sample_tasks = TaskPool() self.replay_batches = [] self.replay_index = 0 self.num_sent_since_broadcast = 0 self.num_weight_syncs = 0 self.num_replayed = 0 self.sample_timesteps = 0 self.sync_sampling = sync_sampling def start(self): # Kick off async background sampling for ev in self.remote_workers: ev.set_weights.remote(self.broadcasted_weights) for _ in range(self.max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.started = True @override(Aggregator) def iter_train_batches(self, max_yield=999, _recursive_called=False): """Iterate over train batches. Arguments: force_yield_all (bool): Whether to return all batches until task pool is drained. max_yield (int): Max number of batches to iterate over in this cycle. Setting this avoids iter_train_batches returning too much data at once. """ assert self.started already_sent_out = False # ev is the rollout worker for ev, sample_batch in self._augment_with_replay( self.sample_tasks.completed_prefetch(blocking_wait=True, max_yield=max_yield)): sample_batch.decompress_if_needed() self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: if len(self.batch_buffer) == 1: # make a defensive copy to avoid sharing plasma memory # across multiple threads train_batch = self.batch_buffer[0].copy() else: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) self.sample_timesteps += train_batch.count if self.sync_sampling: # If sync sampling is set, return batch and then stop. # You need to call start at the outside. # return [train_batch] already_sent_out = True yield train_batch else: yield train_batch self.batch_buffer = [] # If the batch was replayed, skip the update below. 
if ev is None: continue # Put in replay buffer if enabled if self.replay_buffer_num_slots > 0: if len(self.replay_batches) < self.replay_buffer_num_slots: self.replay_batches.append(sample_batch) else: self.replay_batches[self.replay_index] = sample_batch self.replay_index += 1 self.replay_index %= self.replay_buffer_num_slots if already_sent_out and self.sync_sampling: continue ev.set_weights.remote(self.broadcasted_weights) self.num_weight_syncs += 1 self.num_sent_since_broadcast += 1 # Kick off another sample request self.sample_tasks.add(ev, ev.sample.remote()) if self.sync_sampling and (not _recursive_called): while self.sample_tasks.count > 0: # Recursively re-enter to fully drain the remaining tasks in the pool for train_batch in self.iter_train_batches( max_yield, _recursive_called=True): yield train_batch @override(Aggregator) def stats(self): return { "num_weight_syncs": self.num_weight_syncs, "num_steps_replayed": self.num_replayed, "sample_timesteps": self.sample_timesteps } @override(Aggregator) def reset(self, remote_workers): self.sample_tasks.reset_workers(remote_workers) def _augment_with_replay(self, sample_futures): def can_replay(): num_needed = int( np.ceil(self.train_batch_size / self.sample_batch_size)) return len(self.replay_batches) > num_needed for ev, sample_batch in sample_futures: sample_batch = ray_get_and_free(sample_batch) yield ev, sample_batch if can_replay(): f = self.replay_proportion while random.random() < f: f -= 1 replay_batch = random.choice(self.replay_batches) self.num_replayed += replay_batch.count yield None, replay_batch
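# Illustrative sketch: the replay buffer kept by the aggregator above is a
# fixed-size ring -- batches are appended until replay_buffer_num_slots is
# reached, after which the oldest slot is overwritten via the wrapping
# replay_index. RingBuffer below is a hypothetical standalone equivalent.
class RingBuffer:
    def __init__(self, num_slots):
        self.num_slots = num_slots
        self.slots = []
        self.index = 0

    def add(self, item):
        if len(self.slots) < self.num_slots:
            # Still filling up: plain append.
            self.slots.append(item)
        else:
            # Full: overwrite the oldest slot and advance the write cursor.
            self.slots[self.index] = item
            self.index = (self.index + 1) % self.num_slots


if __name__ == "__main__":
    buf = RingBuffer(num_slots=3)
    for i in range(5):
        buf.add(i)
    print(buf.slots)  # [3, 4, 2] -- slots 0 and 1 were overwritten first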