def training_pipeline(workers, config):
    """Build a synchronous (bulk-sync) training pipeline.

    Experiences are sampled in lock-step from all workers. The training mode
    is selected by ``config["microbatch_size"]``:

    * Microbatch mode (truthy ``microbatch_size``): gradients are computed on
      microbatches of at least ``microbatch_size`` samples. Then
      ``ceil(train_batch_size / microbatch_size)`` of those gradients are
      averaged, and the average is applied in a single SGD step. This
      conserves GPU memory while still allowing very large experience
      batches.
    * Normal mode: experiences are concatenated into batches of at least
      ``train_batch_size`` samples, and one SGD step is taken per batch.

    Args:
        workers: The rollout workers. They are passed through to
            ``ParallelRollouts`` and to the gradient/training ops.
        config: Algorithm config dict. Reads ``microbatch_size`` and
            ``train_batch_size``.

    Returns:
        The training iterator wrapped by ``StandardMetricsReporting``.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    microbatch_size = config["microbatch_size"]

    if not microbatch_size:
        # Normal mode: one SGD step per concatenated train batch.
        train_batches = rollouts.combine(
            ConcatBatches(min_batch_size=config["train_batch_size"]))
        train_op = train_batches.for_each(TrainOneStep(workers))
        return StandardMetricsReporting(train_op, workers, config)

    # Microbatch mode: accumulate gradients over several microbatches and
    # apply their average as a single step to conserve GPU memory.
    num_microbatches = math.ceil(config["train_batch_size"] / microbatch_size)
    microbatches = rollouts.combine(
        ConcatBatches(min_batch_size=microbatch_size))
    per_batch_grads = microbatches.for_each(
        ComputeGradients(workers))  # (grads, info)
    grad_groups = per_batch_grads.batch(num_microbatches)  # List[(grads, info)]
    averaged = grad_groups.for_each(AverageGradients())  # (avg_grads, info)
    train_op = averaged.for_each(ApplyGradients(workers))
    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Build an off-policy training plan backed by a local replay buffer.

    Two sub-pipelines are interleaved deterministically (round-robin):

    1. A storing pipeline samples experiences in lock-step from ``workers``
       and adds them to a local ``ReplayBuffer``. Pulling items from this
       pipeline is what drives the storing.
    2. A replay pipeline reads batches from that buffer via ``LocalReplay``
       (sized by ``train_batch_size``). It takes one SGD step on each batch
       with ``TrainOneStep``, then lets ``UpdateTargetNetwork`` decide,
       using ``target_network_update_freq``, whether to update the target
       network.

    Args:
        workers: The rollout workers to sample from and train.
        config: Algorithm config dict. Reads ``buffer_size``,
            ``train_batch_size`` and ``target_network_update_freq``.

    Returns:
        The combined iterator wrapped by ``StandardMetricsReporting``.
    """
    replay_buffer = ReplayBuffer(config["buffer_size"])
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Each next() on this op stores fresh rollouts into the buffer.
    store_op = rollouts.for_each(StoreToReplayBuffer(replay_buffer))

    # (2) Replay from the buffer, train on it, then maybe update the target.
    replayed = LocalReplay(replay_buffer, config["train_batch_size"])
    trained = replayed.for_each(TrainOneStep(workers))
    replay_op = trained.for_each(
        UpdateTargetNetwork(workers, config["target_network_update_freq"]))

    # Alternate strictly between storing (1) and replaying (2).
    train_op = Concurrently([store_op, replay_op], mode="round_robin")
    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Build the basic synchronous training plan.

    Experiences are collected in parallel and in lock-step ("bulk_sync")
    from the rollout workers. They are concatenated until at least
    ``config["train_batch_size"]`` samples are available. Each resulting
    batch is passed to ``TrainOneStep``, which trains the policy and updates
    the workers.

    Args:
        workers: The rollout workers to sample from and train.
        config: Algorithm config dict. Reads ``train_batch_size``.

    Returns:
        A ``LocalIterator[metrics_dict]`` from ``StandardMetricsReporting``
        that yields the standard metrics (e.g. episode reward) for each
        train step.
    """
    # Sample from the workers, then merge batches up to the train size.
    train_batches = ParallelRollouts(workers, mode="bulk_sync").combine(
        ConcatBatches(min_batch_size=config["train_batch_size"]))
    train_op = train_batches.for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Build a distributed prioritized-replay plan with a learner thread.

    Three sub-pipelines run concurrently in "async" mode, each as fast as it
    can make progress:

    1. Storing: rollouts are sampled asynchronously from the workers and
       stored into the replay actors. The worker that produced each batch
       periodically receives the latest weights (see ``UpdateWorkerWeights``).
    2. Replaying: batches are read from the replay actors and put on the
       learner thread's input queue.
    3. Updating: items taken from the learner thread's output queue are used
       to update priorities on the originating replay actor and to record
       learner metrics. ``UpdateTargetNetwork`` (counted by steps trained)
       then decides whether to update the target network.

    Args:
        workers: The worker set. ``workers.local_worker()`` is handed to the
            learner thread and is the source of the weights sent to workers.
        config: Algorithm config dict. Reads
            ``optimizer.num_replay_buffer_shards``,
            ``optimizer.max_weight_sync_delay``, ``learning_starts``,
            ``buffer_size``, ``train_batch_size``,
            ``prioritized_replay_alpha``/``_beta``/``_eps`` and
            ``target_network_update_freq``.

    Returns:
        The merged iterator wrapped by ``StandardMetricsReporting``.
    """
    optimizer_config = config["optimizer"]

    # Create the replay buffer shard actors.
    # TODO(ekl) support batch replay options
    num_shards = optimizer_config["num_replay_buffer_shards"]
    replay_actor_args = [
        num_shards,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
    ]
    replay_actors = create_colocated(ReplayActor, replay_actor_args, num_shards)

    def update_prio_and_stats(item):
        """Apply learner priorities to a replay actor and record metrics.

        ``item`` is unpacked as (replay actor, priority dict, step count).
        ``learner_thread`` is resolved from the enclosing scope at call
        time; it is assigned further below, before any pipeline runs.
        """
        source_actor, priorities, trained_count = item
        source_actor.update_priorities.remote(priorities)
        metrics = LocalIterator.get_metrics()
        # The learner thread trains outside this pipeline, so the
        # steps-trained counter is bumped here by hand.
        metrics.counters[STEPS_TRAINED_COUNTER] += trained_count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    class UpdateWorkerWeights:
        """Send fresh weights to a worker once it has sampled enough steps.

        Called with ``(actor, batch)`` pairs. When an actor has accumulated
        at least ``max_weight_sync_delay`` sampled steps since its last sync,
        it is sent the local worker's weights. The weights are re-published
        with ``ray.put`` on first use or when the learner thread flags
        ``weights_updated``; otherwise the cached reference is reused.
        """

        def __init__(self, learner_thread, workers, max_weight_sync_delay):
            self.learner_thread = learner_thread
            self.workers = workers
            # Steps sampled per actor since that actor's last weight sync.
            self.steps_since_update = collections.defaultdict(int)
            self.max_weight_sync_delay = max_weight_sync_delay
            # Cached object reference for the most recently published weights.
            self.weights = None

        def __call__(self, item):
            actor, batch = item
            self.steps_since_update[actor] += batch.count
            if self.steps_since_update[actor] < self.max_weight_sync_delay:
                return None
            # It's important to pull new weights once they are updated, to
            # avoid excessive correlation between actors.
            if self.weights is None or self.learner_thread.weights_updated:
                self.learner_thread.weights_updated = False
                self.weights = ray.put(
                    self.workers.local_worker().get_weights())
            actor.set_weights.remote(self.weights)
            self.steps_since_update[actor] = 0
            LocalIterator.get_metrics().counters["num_weight_syncs"] += 1

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # (1) Sample asynchronously, store into the replay actors, and refresh
    # the weights of whichever worker produced each batch.
    store_op = (
        ParallelRollouts(workers, mode="async", async_queue_depth=2)
        .for_each(StoreToReplayActors(replay_actors))
        .zip_with_source_actor()
        .for_each(UpdateWorkerWeights(
            learner_thread,
            workers,
            max_weight_sync_delay=optimizer_config["max_weight_sync_delay"],
        )))

    # (2) Feed replayed batches into the learner thread's input queue.
    replay_op = (
        ParallelReplay(replay_actors, async_queue_depth=4)
        .zip_with_source_actor()
        .for_each(Enqueue(learner_thread.inqueue)))

    # (3) Drain learner results, apply priorities and metrics, then let the
    # target-network updater run (driven by steps trained).
    update_op = (
        Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
        .for_each(update_prio_and_stats)
        .for_each(UpdateTargetNetwork(
            workers,
            config["target_network_update_freq"],
            by_steps_trained=True,
        )))

    # Run (1), (2) and (3) asynchronously, as fast as possible.
    merged_op = Concurrently([store_op, replay_op, update_op], mode="async")
    return StandardMetricsReporting(merged_op, workers, config)