def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert (
        "local_replay_buffer" in kwargs
    ), "SlateQ execution plan requires a local replay buffer."

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer.
    # Calling next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step.
    replay_op = (Replay(local_buffer=kwargs["local_replay_buffer"])
                 .for_each(TrainOneStep(workers))
                 .for_each(UpdateTargetNetwork(
                     workers, config["target_network_update_freq"])))

    if config["slateq_strategy"] != "RANDOM":
        # Alternate deterministically between (1) and (2). Only return the
        # output of (2) since training metrics are not available until (2)
        # runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_round_robin_weights(config),
        )
    else:
        # No training is needed for the RANDOM strategy.
        train_op = rollouts

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the PPO algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "PPO execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Collect batches for the trainable policies.
    rollouts = rollouts.for_each(
        SelectExperiences(workers.trainable_policies()))
    # Concatenate the SampleBatches into one.
    rollouts = rollouts.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))
    # Standardize advantages.
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    # Perform one training step on the combined + standardized batch.
    if config["simple_optimizer"]:
        train_op = rollouts.for_each(
            TrainOneStep(
                workers,
                num_sgd_iter=config["num_sgd_iter"],
                sgd_minibatch_size=config["sgd_minibatch_size"]))
    else:
        train_op = rollouts.for_each(
            MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["sgd_minibatch_size"],
                num_sgd_iter=config["num_sgd_iter"],
                num_gpus=config["num_gpus"],
                shuffle_sequences=config["shuffle_sequences"],
                _fake_gpus=config["_fake_gpus"],
                framework=config.get("framework")))

    # Update KL after each round of training.
    train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))

    # Warn about bad reward scales and return training metrics.
    return StandardMetricsReporting(train_op, workers, config) \
        .for_each(lambda result: warn_about_bad_reward_scales(config, result))
def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
    # Collects experiences in parallel from multiple RolloutWorker actors.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Combine experiences batches until we hit `train_batch_size` in size.
    # Then, train the policy on those experiences and update the workers.
    train_op = rollouts \
        .combine(ConcatBatches(
            min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))

    # Add on the standard episode reward, etc. metrics reporting. This
    # returns a LocalIterator[metrics_dict] representing metrics for each
    # train step.
    return StandardMetricsReporting(train_op, workers, config)
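# The plans in this file only describe dataflows; they still need to be
# attached to a Trainer. The commented-out example further below does this via
# `build_trainer(..., execution_plan=...)`. The following is a minimal sketch
# of wiring `default_execution_plan` into a Trainer, assuming an older (~1.x)
# RLlib layout for the imports; the trainer name and config values are
# illustrative assumptions, not taken from the source.

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer_template import build_trainer

# Build a Trainer class whose training loop is `default_execution_plan`.
MyPGTrainer = build_trainer(
    name="MyPGTrainer",
    default_policy=PGTFPolicy,
    execution_plan=default_execution_plan)

# trainer = MyPGTrainer(config={"num_workers": 2}, env="CartPole-v0")
# print(trainer.train())  # each train() call pulls one result dict from the plan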
def execution_plan(workers, config, **kwargs):
    assert len(kwargs) == 0, (
        "Dreamer execution_plan does NOT take any additional parameters")

    # Special replay buffer for Dreamer agent.
    episode_buffer = EpisodicBuffer(length=config["batch_length"])

    local_worker = workers.local_worker()

    # Prefill episode buffer with initial exploration (uniform sampling)
    while total_sampled_timesteps(local_worker) < config["prefill_timesteps"]:
        samples = local_worker.sample()
        episode_buffer.add(samples)

    batch_size = config["batch_size"]
    dreamer_train_iters = config["dreamer_train_iters"]
    act_repeat = config["action_repeat"]

    rollouts = ParallelRollouts(workers)
    rollouts = rollouts.for_each(
        DreamerIteration(local_worker, episode_buffer, dreamer_train_iters,
                         batch_size, act_repeat))
    return rollouts
def test_store_to_replay_local(self):
    buf = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    assert buf.replay() is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

    next(b)
    assert buf.replay() is None  # learning hasn't started yet
    next(b)
    assert buf.replay().count == 100

    replay_op = Replay(local_buffer=buf)
    assert next(replay_op).count == 100
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert "local_replay_buffer" in kwargs, (
        "GenericOffPolicy execution plan requires a local replay buffer.")

    local_replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Generate rollouts and store them in our local replay buffer.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    if config["simple_optimizer"]:
        train_step_op = TrainOneStep(workers)
    else:
        train_step_op = MultiGPUTrainOneStep(
            workers=workers,
            sgd_minibatch_size=config["train_batch_size"],
            num_sgd_iter=1,
            num_gpus=config["num_gpus"],
            shuffle_sequences=True,
            _fake_gpus=config["_fake_gpus"],
            framework=config.get("framework"))

    # (2) Read and train on experiences from the replay buffer.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(train_step_op) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2).
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert len(kwargs) == 0, (
        "Alpha zero execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    if config["simple_optimizer"]:
        train_op = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )).for_each(
                TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
    else:
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts \
            .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        replay_op = Replay(local_buffer=replay_buffer) \
            .filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )) \
            .for_each(TrainOneStep(
                workers, num_sgd_iter=config["num_sgd_iter"]))

        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def test_store_to_replay_actor(self):
    actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        buffer_size=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    assert ray.get(actor.replay.remote()) is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(actors=[actor]))

    next(b)
    assert ray.get(
        actor.replay.remote()) is None  # learning hasn't started
    next(b)
    assert ray.get(actor.replay.remote()).count == 100

    replay_op = Replay(actors=[actor])
    assert next(replay_op).count == 100
def test_store_to_replay_actor(self):
    ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
    actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    assert len(ray.get(actor.sample.remote(100))) == 0

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(actors=[actor]))

    next(b)
    assert len(ray.get(
        actor.sample.remote(100))) == 0  # learning hasn't started
    next(b)
    assert ray.get(actor.sample.remote(100)).count == 100

    replay_op = Replay(actors=[actor], num_items_to_replay=100)
    assert next(replay_op).count == 100
def execution_plan(workers, config):
    # For A3C, compute policy gradients remotely on the rollout workers.
    # rollouts = ParallelRollouts(workers, mode="bulk_sync")
    grads = AsyncGradients(workers)

    # Apply the gradients as they arrive. We set update_all to False so that
    # only the worker sending the gradient is updated with new weights.
    # train_op = grads.for_each(ApplyGradients(workers, update_all=False))

    # Split the remote workers into two sub-sets and train each sub-set with
    # its own multi-GPU train op.
    temp1 = workers
    temp2 = workers
    rem1 = workers.remote_workers()[0:6]
    rem2 = workers.remote_workers()[6:11]
    temp1.reset(rem1)
    temp2.reset(rem2)
    rollouts1 = ParallelRollouts(temp1, mode="bulk_sync")
    rollouts2 = ParallelRollouts(temp2, mode="bulk_sync")

    train_step_op1 = TrainTFMultiGPU(
        workers=temp1,
        sgd_minibatch_size=config["train_batch_size"],
        num_sgd_iter=1,
        num_gpus=config["num_gpus"],
        shuffle_sequences=True,
        _fake_gpus=config["_fake_gpus"],
        framework=config.get("framework"))
    train_step_op2 = TrainTFMultiGPU(
        workers=temp2,
        sgd_minibatch_size=config["train_batch_size"],
        num_sgd_iter=1,
        num_gpus=config["num_gpus"],
        shuffle_sequences=True,
        _fake_gpus=config["_fake_gpus"],
        framework=config.get("framework"))

    train_op1 = rollouts1.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"][
                "count_steps_by"])).for_each(train_step_op1)
    train_op2 = rollouts2.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"][
                "count_steps_by"])).for_each(train_step_op2)

    return StandardMetricsReporting(train_op1, temp1, config).union(
        StandardMetricsReporting(train_op2, temp2, config))
def gather_experiences_directly(workers, config):
    rollouts = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_sample_requests_in_flight_per_worker"])

    # Augment with replay and concat to desired train batch size.
    train_batches = rollouts \
        .for_each(lambda batch: batch.decompress_if_needed()) \
        .for_each(MixInReplay(
            num_slots=config["replay_buffer_num_slots"],
            replay_proportion=config["replay_proportion"])) \
        .flatten() \
        .combine(
            ConcatBatches(min_batch_size=config["train_batch_size"]))

    return train_batches
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Collect large batches of relevant experiences & standardize.
    rollouts = rollouts.for_each(
        SelectExperiences(workers.trainable_policies()))
    rollouts = rollouts.combine(
        ConcatBatches(min_batch_size=config["train_batch_size"]))
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    if config["simple_optimizer"]:
        train_op = rollouts.for_each(
            TrainOneStep(
                workers,
                num_sgd_iter=config["num_sgd_iter"],
                sgd_minibatch_size=config["sgd_minibatch_size"]))
    else:
        train_op = rollouts.for_each(
            TrainTFMultiGPU(
                workers,
                sgd_minibatch_size=config["sgd_minibatch_size"],
                num_sgd_iter=config["num_sgd_iter"],
                num_gpus=config["num_gpus"],
                rollout_fragment_length=config["rollout_fragment_length"],
                num_envs_per_worker=config["num_envs_per_worker"],
                train_batch_size=config["train_batch_size"],
                shuffle_sequences=config["shuffle_sequences"],
                _fake_gpus=config["_fake_gpus"]))

    # Callback to update the KL based on optimization info.
    def update_kl(item):
        _, fetches = item

        def update(pi, pi_id):
            if pi_id in fetches:
                pi.update_kl(fetches[pi_id]["kl"])
            else:
                logger.warning(
                    "No data for {}, not updating kl".format(pi_id))

        workers.local_worker().foreach_trainable_policy(update)

    # Update KL after each round of training.
    train_op = train_op.for_each(update_kl)

    return StandardMetricsReporting(train_op, workers, config) \
        .for_each(lambda result: _warn_about_bad_reward_scales(config, result))
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert (
        len(kwargs) == 0
    ), "PPO execution_plan does NOT take any additional parameters"

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Collect batches for the trainable policies.
    rollouts = rollouts.for_each(
        SelectExperiences(local_worker=workers.local_worker()))
    # Concatenate the SampleBatches into one.
    rollouts = rollouts.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))
    # Standardize advantages.
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    # Perform one training step on the combined + standardized batch.
    if config["simple_optimizer"]:
        train_op = rollouts.for_each(
            TrainOneStep(
                workers,
                num_sgd_iter=config["num_sgd_iter"],
                sgd_minibatch_size=config["sgd_minibatch_size"],
            ))
    else:
        train_op = rollouts.for_each(
            MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["sgd_minibatch_size"],
                num_sgd_iter=config["num_sgd_iter"],
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            ))

    # Update KL after each round of training.
    train_op = train_op.for_each(lambda t: t[1]).for_each(
        UpdateKL(workers))

    # Warn about bad reward scales and return training metrics.
    return StandardMetricsReporting(train_op, workers, config).for_each(
        lambda result: warn_about_bad_reward_scales(config, result))
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the MARWIL/BC algorithm.
    Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "MARWIL execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = LocalReplayBuffer(
        learning_starts=config["learning_starts"],
        capacity=config["replay_buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_sequence_length=1,
    )

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    train_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"])) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(merged_op, workers, config)
def evaluate(trainer, num_episodes=20):
    ret_reward = []
    ret_length = []
    ret_success_rate = []
    ret_out_rate = []
    ret_crash_vehicle_rate = []
    start = time.time()
    episode_count = 0
    while episode_count < num_episodes:
        rollouts = ParallelRollouts(trainer.workers, mode="bulk_sync")
        batch = next(rollouts)
        episodes = batch.split_by_episode()
        ret_reward.extend([e["rewards"].sum() for e in episodes])
        ret_length.extend([e.count for e in episodes])
        ret_success_rate.extend(
            [e["infos"][-1][TerminationState.SUCCESS] for e in episodes])
        ret_out_rate.extend(
            [e["infos"][-1][TerminationState.OUT_OF_ROAD] for e in episodes])
        ret_crash_vehicle_rate.extend(
            [e["infos"][-1][TerminationState.CRASH_VEHICLE] for e in episodes])
        episode_count += len(episodes)
        print("Finished {} episodes".format(episode_count))

    ret = dict(
        reward=np.mean(ret_reward),
        length=np.mean(ret_length),
        success_rate=np.mean(ret_success_rate),
        out_rate=np.mean(ret_out_rate),
        crash_vehicle_rate=np.mean(ret_crash_vehicle_rate),
        episode_count=episode_count,
        time=time.time() - start,
    )
    print("We collected {} episodes. Spent: {:.3f} s.\nResult: {}".format(
        episode_count, time.time() - start,
        {k: round(v, 3) for k, v in ret.items()}))
    return ret
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Collect large batches of relevant experiences & standardize.
    rollouts = rollouts.for_each(
        SelectExperiences(workers.trainable_policies()))
    rollouts = rollouts.combine(
        ConcatBatches(min_batch_size=config["train_batch_size"]))
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    # Only the simple optimizer path is supported by this plan; `train_op`
    # would otherwise be undefined below.
    if config["simple_optimizer"]:
        train_op = rollouts.for_each(
            TrainOneStep(
                workers,
                num_sgd_iter=config["num_sgd_iter"],
                sgd_minibatch_size=config["sgd_minibatch_size"]))
    else:
        raise ValueError(
            "This execution plan only supports simple_optimizer=True.")

    # Update KL after each round of training.
    train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))

    return StandardMetricsReporting(train_op, workers, config) \
        .for_each(lambda result: warn_about_bad_reward_scales(config, result))
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert (
        "local_replay_buffer" in kwargs
    ), "DQN's execution plan requires a local replay buffer."

    # Assign to Trainer, so we can store the MultiAgentReplayBuffer's
    # data when we save checkpoints.
    local_replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer.
    # Calling next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info[LEARNER_STATS_KEY].get("td_error"))
                samples.policy_batches[policy_id].set_get_interceptor(None)
                batch_indices = samples.policy_batches[policy_id].get(
                    "batch_indexes")
                # In case the buffer stores sequences, TD-error could
                # already be calculated per sequence chunk.
                if len(batch_indices) != len(td_error):
                    T = local_replay_buffer.replay_sequence_length
                    assert (len(batch_indices) > len(td_error)
                            and len(batch_indices) % T == 0)
                    batch_indices = batch_indices.reshape([-1, T])[:, 0]
                    assert len(batch_indices) == len(td_error)
                prio_dict[policy_id] = (batch_indices, td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target
    # network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

    if config["simple_optimizer"]:
        train_step_op = TrainOneStep(workers)
    else:
        train_step_op = MultiGPUTrainOneStep(
            workers=workers,
            sgd_minibatch_size=config["train_batch_size"],
            num_sgd_iter=1,
            num_gpus=config["num_gpus"],
            _fake_gpus=config["_fake_gpus"],
        )

    replay_op = (Replay(local_buffer=local_replay_buffer)
                 .for_each(lambda x: post_fn(x, workers, config))
                 .for_each(train_step_op)
                 .for_each(update_prio)
                 .for_each(UpdateTargetNetwork(
                     workers, config["target_network_update_freq"])))

    # Alternate deterministically between (1) and (2).
    # Only return the output of (2) since training metrics are not
    # available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config),
    )

    return StandardMetricsReporting(train_op, workers, config)
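# Several plans above pass `round_robin_weights=calculate_rr_weights(config)`
# to Concurrently() so that the sampling op (1) and the training op (2) are
# stepped at a configurable ratio. The helper itself is not shown in this
# file; below is a minimal sketch of what such a helper could look like,
# assuming `training_intensity` is expressed as the desired ratio of replayed
# to sampled timesteps (an assumption, not taken from the source).


def calculate_rr_weights_sketch(config):
    """Return [sample_op_weight, replay_op_weight] for Concurrently()."""
    if not config.get("training_intensity"):
        # Step both ops equally often.
        return [1, 1]
    # Timesteps trained per timestep sampled when both ops are stepped once,
    # e.g. train_batch_size=512, rollout_fragment_length=50 -> 10.24.
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])
    # Weight the replay op so the realized ratio matches training_intensity.
    return [1, config["training_intensity"] / native_ratio]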
def apex_execution_plan(workers: WorkerSet,
                        config: dict) -> LocalIterator[dict]:
    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    replay_actors = create_colocated(ReplayActor, [
        num_replay_buffer_shards,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config.get("replay_sequence_length", 1),
    ], num_replay_buffer_shards)

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
        actor, prio_dict, count = item
        actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner thread
        # is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our replay buffer actors. Update
    # the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(actors=replay_actors))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        store_op = store_op.zip_with_source_actor() \
            .for_each(UpdateWorkerWeights(
                learner_thread, workers,
                max_weight_sync_delay=(
                    config["optimizer"]["max_weight_sync_delay"])
            ))

    # (2) Read experiences from the replay buffer actors and send to the
    # learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(actors=replay_actors, num_async=4) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .zip_with_source_actor() \
        .for_each(Enqueue(learner_thread.inqueue))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive) \
        .for_each(update_prio_and_stats) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"],
            by_steps_trained=True))

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights)
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result: dict) -> dict:
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_trainable_policy(
            lambda p, _: p.get_exploration_info())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            "learner": copy.deepcopy(learner_thread.stats),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    selected_workers = workers.remote_workers()[
        -len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)
def execution_plan(workers, config):
    # Create a number of replay buffer actors.
    # TODO(ekl) support batch replay options
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    replay_actors = create_colocated(ReplayActor, [
        num_replay_buffer_shards,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
    ], num_replay_buffer_shards)

    # Update experience priorities post learning.
    def update_prio_and_stats(item):
        actor, prio_dict, count = item
        actor.update_priorities.remote(prio_dict)
        metrics = LocalIterator.get_metrics()
        # Manually update the steps trained counter since the learner thread
        # is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # Update worker weights as they finish generating experiences.
    class UpdateWorkerWeights:
        def __init__(self, learner_thread, workers, max_weight_sync_delay):
            self.learner_thread = learner_thread
            self.workers = workers
            self.steps_since_update = collections.defaultdict(int)
            self.max_weight_sync_delay = max_weight_sync_delay
            self.weights = None

        def __call__(self, item):
            actor, batch = item
            self.steps_since_update[actor] += batch.count
            if self.steps_since_update[actor] >= self.max_weight_sync_delay:
                # Note that it's important to pull new weights once
                # updated to avoid excessive correlation between actors.
                if self.weights is None or \
                        self.learner_thread.weights_updated:
                    self.learner_thread.weights_updated = False
                    self.weights = ray.put(
                        self.workers.local_worker().get_weights())
                actor.set_weights.remote(self.weights)
                self.steps_since_update[actor] = 0
                # Update metrics.
                metrics = LocalIterator.get_metrics()
                metrics.counters["num_weight_syncs"] += 1

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our replay buffer actors. Update
    # the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(actors=replay_actors)) \
        .zip_with_source_actor() \
        .for_each(UpdateWorkerWeights(
            learner_thread, workers,
            max_weight_sync_delay=config["optimizer"]["max_weight_sync_delay"])
        )

    # (2) Read experiences from the replay buffer actors and send to the
    # learner thread via its in-queue.
    replay_op = Replay(actors=replay_actors, num_async=4) \
        .zip_with_source_actor() \
        .for_each(Enqueue(learner_thread.inqueue))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive) \
        .for_each(update_prio_and_stats) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"],
            by_steps_trained=True))

    # Execute (1), (2), (3) asynchronously as fast as possible. Only output
    # items from (3) since metrics aren't available before then.
    merged_op = Concurrently(
        [store_op, replay_op, update_op], mode="async", output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result):
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_trainable_policy(
            lambda p, _: p.get_exploration_info())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            "learner": copy.deepcopy(learner_thread.stats),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    selected_workers = workers.remote_workers(
    )[-len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the DD-PPO algorithm. Defines the distributed
    dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "DDPPO execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="raw")

    # Setup the distributed processes.
    if not workers.remote_workers():
        raise ValueError("This optimizer requires >0 remote workers.")
    ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
    port = ray.get(workers.remote_workers()[0].find_free_port.remote())
    address = "tcp://{ip}:{port}".format(ip=ip, port=port)
    logger.info(
        "Creating torch process group with leader {}".format(address))

    # Get setup tasks in order to throw errors on failure.
    ray.get([
        worker.setup_torch_data_parallel.remote(
            url=address,
            world_rank=i,
            world_size=len(workers.remote_workers()),
            backend=config["torch_distributed_backend"])
        for i, worker in enumerate(workers.remote_workers())
    ])
    logger.info("Torch process group init completed")

    # This function is applied remotely on each rollout worker.
    def train_torch_distributed_allreduce(batch):
        expected_batch_size = (config["rollout_fragment_length"] *
                               config["num_envs_per_worker"])
        this_worker = get_global_worker()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, config["sgd_minibatch_size"]))
        info = do_minibatch_sgd(batch, this_worker.policy_map, this_worker,
                                config["num_sgd_iter"],
                                config["sgd_minibatch_size"], ["advantages"])
        return info, batch.count

    # Broadcast the local set of global vars.
    def update_worker_global_vars(item):
        global_vars = _get_global_vars()
        for w in workers.remote_workers():
            w.set_global_vars.remote(global_vars)
        return item

    # Have to manually record stats since we are using "raw" rollouts mode.
    class RecordStats:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, items):
            for item in items:
                info, count = item
                metrics = _get_shared_metrics()
                metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
                metrics.counters[STEPS_SAMPLED_COUNTER] += count
                metrics.counters[STEPS_TRAINED_COUNTER] += count
                metrics.info[LEARNER_INFO] = info
                # Since SGD happens remotely, the time delay between fetch and
                # completion is approximately the SGD step time.
                metrics.timers[LEARN_ON_BATCH_TIMER].push(
                    time.perf_counter() - self.fetch_start_time)

    train_op = (
        rollouts.for_each(train_torch_distributed_allreduce)  # allreduce
        .batch_across_shards()  # List[(grad_info, count)]
        .for_each(RecordStats()))

    train_op = train_op.for_each(update_worker_global_vars)

    # Sync down the weights. As with the sync up, this is not really
    # needed unless the user is reading the local weights.
    if config["keep_local_weights_in_sync"]:

        def download_weights(item):
            workers.local_worker().set_weights(
                ray.get(workers.remote_workers()[0].get_weights.remote()))
            return item

        train_op = train_op.for_each(download_weights)

    # In debug mode, check the allreduce successfully synced the weights.
    if logger.isEnabledFor(logging.DEBUG):

        def check_sync(item):
            weights = ray.get(
                [w.get_weights.remote() for w in workers.remote_workers()])
            sums = []
            for w in weights:
                acc = 0
                for p in w.values():
                    for k, v in p.items():
                        acc += v.sum()
                sums.append(float(acc))
            logger.debug("The worker weight sums are {}".format(sums))
            assert len(set(sums)) == 1, sums

        train_op = train_op.for_each(check_sync)

    return StandardMetricsReporting(train_op, workers, config)
def custom_training_workflow_ppo_ddpg(workers: WorkerSet, config: dict):
    # Note: this buffer is created but not used by the PPO/DDPG flows below.
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        buffer_size=50000,
        replay_batch_size=64)

    def add_ppo_metrics(batch):
        print("PPO policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.env_steps(),
              "agent steps", batch.env_steps())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_PPO"] += batch.env_steps()
        return batch

    def add_ddpg_metrics(batch):
        print("DDPG policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.env_steps(),
              "agent steps", batch.env_steps())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_DDPG"] += batch.env_steps()
        return batch

    # Generate common experiences; each sub-flow consumes its own duplicate.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # PPO sub-flow.
    ppo_train_op = r1.for_each(SelectExperiences(["PPO_policy"])) \
        .combine(ConcatBatches(
            min_batch_size=200)) \
        .for_each(add_ppo_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(
            workers,
            policies=["PPO_policy"],
            num_sgd_iter=10,
            sgd_minibatch_size=128))

    # DDPG sub-flow.
    ddpg_train_op = r2.for_each(SelectExperiences(["DDPG_policy"])) \
        .combine(ConcatBatches(
            min_batch_size=200)) \
        .for_each(add_ddpg_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(
            workers,
            policies=["DDPG_policy"],
            num_sgd_iter=10,
            sgd_minibatch_size=128))
    # , count_steps_by="env_steps")) \

    # Combined training flow
    train_op = Concurrently(
        [ppo_train_op, ddpg_train_op], mode="async", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)


# if __name__ == "__main__":
#     args = parser.parse_args()
#
#     assert not (args.torch and args.mixed_torch_tf),\
#         "Use either --torch or --mixed-torch-tf, not both!"
#
#     ray.init()
#
#     # Simple environment with 4 independent cartpole entities
#     register_env("multi_agent_cartpole",
#                  lambda _: MultiAgentCartPole({"num_agents": 4}))
#     single_env = gym.make("CartPole-v0")
#     obs_space = single_env.observation_space
#     act_space = single_env.action_space
#
#     # Note that since the trainer below does not include a default policy or
#     # policy configs, we have to explicitly set it in the multiagent config:
#     policies = {
#         "ppo_policy": (PPOTorchPolicy if args.torch or args.mixed_torch_tf
#                        else PPOTFPolicy, obs_space, act_space, PPO_CONFIG),
#         "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy,
#                        obs_space, act_space, DQN_CONFIG),
#     }
#
#     def policy_mapping_fn(agent_id):
#         if agent_id % 2 == 0:
#             return "ppo_policy"
#         else:
#             return "dqn_policy"
#
#     MyTrainer = build_trainer(
#         name="PPO_DQN_MultiAgent",
#         default_policy=None,
#         execution_plan=custom_training_workflow)
#
#     config = {
#         "rollout_fragment_length": 50,
#         "num_workers": 0,
#         "env": "multi_agent_cartpole",
#         "multiagent": {
#             "policies": policies,
#             "policy_mapping_fn": policy_mapping_fn,
#             "policies_to_train": ["dqn_policy", "ppo_policy"],
#         },
#         # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
#         "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
#         "framework": "torch" if args.torch else "tf",
#         "_use_trajectory_view_api": True,
#     }
#
#     stop = {
#         "training_iteration": args.stop_iters,
#         "timesteps_total": args.stop_timesteps,
#         "episode_reward_mean": args.stop_reward,
#     }
#
#     results = tune.run(MyTrainer, config=config, stop=stop)
#
#     if args.as_test:
#         check_learning_achieved(results, args.stop_reward)
#
#     ray.shutdown()
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the DQN algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config.get("replay_sequence_length", 1),
        replay_burn_in=config.get("burn_in", 0),
        replay_zero_init_states=config.get("zero_init_states", True),
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer. Calling
    # next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info[LEARNER_STATS_KEY].get("td_error"))
                samples.policy_batches[policy_id].set_get_interceptor(None)
                prio_dict[policy_id] = (
                    samples.policy_batches[policy_id].get("batch_indexes"),
                    td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target
    # network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

    if config["simple_optimizer"]:
        train_step_op = TrainOneStep(workers)
    else:
        train_step_op = TrainTFMultiGPU(
            workers=workers,
            sgd_minibatch_size=config["train_batch_size"],
            num_sgd_iter=1,
            num_gpus=config["num_gpus"],
            shuffle_sequences=True,
            _fake_gpus=config["_fake_gpus"],
            framework=config.get("framework"))

    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(train_step_op) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
def custom_training_workflow(workers: WorkerSet, config: dict):
    local_replay_buffer = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        capacity=50000,
        replay_batch_size=64)

    def add_ppo_metrics(batch):
        print(
            "PPO policy learning on samples from",
            batch.policy_batches.keys(),
            "env steps",
            batch.env_steps(),
            "agent steps",
            batch.env_steps(),
        )
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_PPO"] += batch.env_steps()
        return batch

    def add_dqn_metrics(batch):
        print(
            "DQN policy learning on samples from",
            batch.policy_batches.keys(),
            "env steps",
            batch.env_steps(),
            "agent steps",
            batch.env_steps(),
        )
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_DQN"] += batch.env_steps()
        return batch

    # Generate common experiences.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # DQN sub-flow.
    dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])).for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))
    dqn_replay_op = (Replay(local_buffer=local_replay_buffer)
                     .for_each(add_dqn_metrics)
                     .for_each(TrainOneStep(workers, policies=["dqn_policy"]))
                     .for_each(UpdateTargetNetwork(
                         workers,
                         target_update_freq=500,
                         policies=["dqn_policy"])))
    dqn_train_op = Concurrently(
        [dqn_store_op, dqn_replay_op],
        mode="round_robin",
        output_indexes=[1])

    # PPO sub-flow.
    ppo_train_op = (r2.for_each(SelectExperiences(["ppo_policy"]))
                    .combine(ConcatBatches(
                        min_batch_size=200, count_steps_by="env_steps"))
                    .for_each(add_ppo_metrics)
                    .for_each(StandardizeFields(["advantages"]))
                    .for_each(TrainOneStep(
                        workers,
                        policies=["ppo_policy"],
                        num_sgd_iter=10,
                        sgd_minibatch_size=128,
                    )))

    # Combined training flow
    train_op = Concurrently(
        [ppo_train_op, dqn_train_op], mode="async", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="raw")

    # Setup the distributed processes.
    if not workers.remote_workers():
        raise ValueError("This optimizer requires >0 remote workers.")
    ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
    port = ray.get(workers.remote_workers()[0].find_free_port.remote())
    address = "tcp://{ip}:{port}".format(ip=ip, port=port)
    logger.info("Creating torch process group with leader {}".format(address))

    # Get setup tasks in order to throw errors on failure.
    ray.get([
        worker.setup_torch_data_parallel.remote(
            address, i, len(workers.remote_workers()), backend="nccl")
        # address, i, len(workers.remote_workers()), backend="gloo")
        for i, worker in enumerate(workers.remote_workers())
    ])
    logger.info("Torch process group init completed")

    # This function is applied remotely on each rollout worker.
    def train_torch_distributed_allreduce(batch):
        expected_batch_size = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"])
        this_worker = get_global_worker()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, config["sgd_minibatch_size"]))
        info = do_minibatch_sgd(batch, this_worker.policy_map, this_worker,
                                config["num_sgd_iter"],
                                config["sgd_minibatch_size"], ["advantages"])
        return info, batch.count

    # Have to manually record stats since we are using "raw" rollouts mode.
    class RecordStats:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, items):
            for item in items:
                info, count = item
                metrics = LocalIterator.get_metrics()
                metrics.counters[STEPS_SAMPLED_COUNTER] += count
                metrics.counters[STEPS_TRAINED_COUNTER] += count
                metrics.info[LEARNER_INFO] = info
                # Since SGD happens remotely, the time delay between fetch and
                # completion is approximately the SGD step time.
                metrics.timers[LEARN_ON_BATCH_TIMER].push(
                    time.perf_counter() - self.fetch_start_time)

    train_op = (
        rollouts.for_each(train_torch_distributed_allreduce)  # allreduce
        .batch_across_shards()  # List[(grad_info, count)]
        .for_each(RecordStats()))

    # Sync down the weights. As with the sync up, this is not really
    # needed unless the user is reading the local weights.
    if config["keep_local_weights_in_sync"]:

        def download_weights(item):
            workers.local_worker().set_weights(
                ray.get(workers.remote_workers()[0].get_weights.remote()))
            return item

        train_op = train_op.for_each(download_weights)

    # In debug mode, check the allreduce successfully synced the weights.
    if logger.isEnabledFor(logging.DEBUG):

        def check_sync(item):
            weights = ray.get(
                [w.get_weights.remote() for w in workers.remote_workers()])
            sums = []
            for w in weights:
                acc = 0
                for p in w.values():
                    for k, v in p.items():
                        acc += v.sum()
                sums.append(float(acc))
            logger.debug("The worker weight sums are {}".format(sums))
            assert len(set(sums)) == 1, sums

        train_op = train_op.for_each(check_sync)

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers: WorkerSet, config: dict,
                   **kwargs) -> LocalIterator[dict]:
    assert (
        len(kwargs) == 0
    ), "Apex execution_plan does NOT take any additional parameters"

    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    buffer_size = (config["replay_buffer_config"]["capacity"] //
                   num_replay_buffer_shards)
    replay_actor_args = [
        num_replay_buffer_shards,
        config["learning_starts"],
        buffer_size,
        config["train_batch_size"],
        config["replay_buffer_config"]["prioritized_replay_alpha"],
        config["replay_buffer_config"]["prioritized_replay_beta"],
        config["replay_buffer_config"]["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config["replay_buffer_config"].get("replay_sequence_length", 1),
    ]
    # Place all replay buffer shards on the same node as the learner
    # (driver process that runs this execution plan).
    if config["replay_buffer_shards_colocated_with_driver"]:
        replay_actors = create_colocated_actors(
            actor_specs=[
                # (class, args, kwargs={}, count)
                (ReplayActor, replay_actor_args, {},
                 num_replay_buffer_shards)
            ],
            node=platform.node(),  # localhost
        )[0]  # [0]=only one item in `actor_specs`.
    # Place replay buffer shards on any node(s).
    else:
        replay_actors = [
            ReplayActor(*replay_actor_args)
            for _ in range(num_replay_buffer_shards)
        ]

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(
            item: Tuple[ActorHandle, dict, int, int]) -> None:
        actor, prio_dict, env_count, agent_count = item
        if config.get("prioritized_replay"):
            actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner
        # thread is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
        metrics.counters[STEPS_TRAINED_COUNTER] += env_count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in one of our replay buffer
    # actors. Update the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        store_op = store_op.zip_with_source_actor().for_each(
            UpdateWorkerWeights(
                learner_thread,
                workers,
                max_weight_sync_delay=(
                    config["optimizer"]["max_weight_sync_delay"]),
            ))

    # (2) Read experiences from one of the replay buffer actors and send
    # to the learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = (Replay(actors=replay_actors, num_async=4)
                 .for_each(lambda x: post_fn(x, workers, config))
                 .zip_with_source_actor()
                 .for_each(Enqueue(learner_thread.inqueue)))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = (Dequeue(learner_thread.outqueue,
                         check=learner_thread.is_alive)
                 .for_each(update_prio_and_stats)
                 .for_each(UpdateTargetNetwork(
                     workers,
                     config["target_network_update_freq"],
                     by_steps_trained=True)))

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights,
        )
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before
        # then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result: dict) -> dict:
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_policy_to_train(
            lambda p, _: p.get_exploration_state())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of
    # epsilons.
    selected_workers = workers.remote_workers(
    )[-len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)
def execution_plan(workers, config):
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        multiagent_sync_replay=config.get("multiagent_sync_replay"),
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer. Calling
    # next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (
                    samples.policy_batches[policy_id].data.get(
                        "batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target
    # network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
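# All of the plans above return a LocalIterator[dict] from
# StandardMetricsReporting, and a Trainer drives training simply by calling
# next() on that iterator once per training iteration (this is what the
# "Calling next() on store_op drives this" comments refer to). Below is a
# minimal, hypothetical driver under that assumption; the function name and
# the choice of plan are illustrative, not taken from the source.


def drive_plan_for_n_iterations(workers, config, local_replay_buffer, n=3):
    """Manually step an execution plan and collect its per-iteration results."""
    train_exec_impl = execution_plan(
        workers, config, local_replay_buffer=local_replay_buffer)
    results = []
    for _ in range(n):
        # Each next() performs one round of the dataflow (sample, store,
        # replay, train, report) and yields a standard RLlib result dict.
        results.append(next(train_exec_impl))
    return results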