def __init__(
    self,
    workers: WorkerSet,
    policies: List[PolicyID] = frozenset([]),
    num_sgd_iter: int = 1,
    sgd_minibatch_size: int = 0,
):
    self.workers = workers
    self.local_worker = workers.local_worker()
    self.policies = policies
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size

def __init__(self,
             workers: WorkerSet,
             target_update_freq: int,
             by_steps_trained: bool = False,
             policies: List[PolicyID] = frozenset([])):
    self.workers = workers
    self.target_update_freq = target_update_freq
    self.policies = policies or workers.local_worker().policies_to_train
    if by_steps_trained:
        self.metric = STEPS_TRAINED_COUNTER
    else:
        self.metric = STEPS_SAMPLED_COUNTER

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the DQN algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    replay_buffer_actor = ReservoirReplayActor.remote(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config["replay_sequence_length"],
    )

    # Store a handle for the replay buffer actor in the local worker.
    workers.local_worker().replay_buffer_actor = replay_buffer_actor

    # Read and train on experiences from the replay buffer. Every batch
    # returned from the Replay iterator is passed to TrainOneStep to
    # take an SGD step.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

    print("running replay op..")

    def gen_replay(_):
        while True:
            item = ray.get(replay_buffer_actor.replay.remote())
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    replay_op = LocalIterator(gen_replay, SharedMetrics()) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers))

    replay_op = StandardMetricsReporting(replay_op, workers, config)

    replay_op = map(
        lambda x: x if not isinstance(x, _NextValueNotReady) else {},
        replay_op)

    return replay_op

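# The `before_learn_on_batch` hook used above is simply called as
# `post_fn(batch, workers, config)` on every replayed batch before the
# TrainOneStep SGD update. Below is a minimal, purely illustrative sketch of
# such a hook (the name `scale_rewards` and the reward scaling itself are
# assumptions, not part of the plan above; it also assumes a single-agent
# SampleBatch).
def scale_rewards(batch, workers, config):
    # Example transformation: scale rewards before the batch is trained on.
    batch["rewards"] = batch["rewards"] * 0.1
    return batch

# The hook would be supplied via the config key read by the plan above:
# config["before_learn_on_batch"] = scale_rewards
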
def __init__(self,
             *,
             workers: WorkerSet,
             sgd_minibatch_size: int,
             num_sgd_iter: int,
             num_gpus: int,
             shuffle_sequences: bool,
             policies: List[PolicyID] = frozenset([]),
             _fake_gpus: bool = False,
             framework: str = "tf"):
    self.workers = workers
    self.local_worker = workers.local_worker()
    self.policies = policies
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
    self.shuffle_sequences = shuffle_sequences
    self.framework = framework

    # Collect actual GPU devices to use.
    if not num_gpus:
        _fake_gpus = True
        num_gpus = 1
    type_ = "cpu" if _fake_gpus else "gpu"
    self.devices = [
        "/{}:{}".format(type_, 0 if _fake_gpus else i)
        for i in range(int(math.ceil(num_gpus)))
    ]

    # Total batch size (all towers). Make sure it is divisible by the
    # number of towers.
    self.batch_size = int(sgd_minibatch_size / len(self.devices)) * len(
        self.devices)
    assert self.batch_size % len(self.devices) == 0
    assert self.batch_size >= len(self.devices), "batch size too small"

    # Batch size per tower.
    self.per_device_batch_size = int(self.batch_size / len(self.devices))

    # Per-GPU graph copies created below must share vars with the policy.
    # Reuse is set to AUTO_REUSE because Adam nodes are created after
    # all of the device copies are created.
    self.optimizers = {}
    with self.workers.local_worker().tf_sess.graph.as_default():
        with self.workers.local_worker().tf_sess.as_default():
            for policy_id in (self.policies
                              or self.local_worker.policies_to_train):
                self.add_optimizer(policy_id)

            self.sess = self.workers.local_worker().tf_sess
            self.sess.run(tf1.global_variables_initializer())

def sync_stats(workers: WorkerSet) -> None:
    def get_normalizations(worker):
        policy = worker.policy_map[DEFAULT_POLICY_ID]
        return policy.dynamics_model.normalizations

    def set_normalizations(policy, pid, normalizations):
        policy.dynamics_model.set_norms(normalizations)

    if workers.remote_workers():
        normalization_dict = ray.put(
            get_normalizations(workers.local_worker()))
        set_func = ray.put(set_normalizations)
        for e in workers.remote_workers():
            e.foreach_policy.remote(
                set_func, normalizations=normalization_dict)

def synchronous_parallel_sample(
    worker_set: WorkerSet,
    remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
) -> List[SampleBatch]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), the local worker is used
    for sampling.

    Instead of calling `worker.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the worker(s) instead.

    Args:
        worker_set: The WorkerSet to use for sampling.
        remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
            of `worker.sample.remote()` to generate the requests.

    Returns:
        The list of collected sample batch types (one for each parallel
        rollout worker in the given `worker_set`).

    Examples:
        >>> # Define an RLlib trainer.
        >>> trainer = ... # doctest: +SKIP
        >>> # 2 remote workers (num_workers=2):
        >>> batches = synchronous_parallel_sample(trainer.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        2
        >>> print(batches[0]) # doctest: +SKIP
        SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
        >>> # 0 remote workers (num_workers=0): Using the local worker.
        >>> batches = synchronous_parallel_sample(trainer.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        1
    """
    # No remote workers in the set -> Use local worker for collecting
    # samples.
    if not worker_set.remote_workers():
        return [worker_set.local_worker().sample()]

    # Loop over remote workers' `sample()` method in parallel.
    # NOTE: `remote_fn` is accepted by the signature but not applied by this
    # implementation; all remote workers are sampled via `sample.remote()`.
    sample_batches = ray.get(
        [r.sample.remote() for r in worker_set.remote_workers()])

    # Return all collected batches.
    return sample_batches

def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert len(kwargs) == 0, (
        "PPO execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Collect batches for the trainable policies.
    rollouts = rollouts.for_each(
        SelectExperiences(local_worker=workers.local_worker()))

    # Concatenate the SampleBatches into one.
    rollouts = rollouts.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))

    # Standardize advantages.
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    # Perform one training step on the combined + standardized batch.
    if config["simple_optimizer"]:
        train_op = rollouts.for_each(
            TrainOneStep(
                workers,
                num_sgd_iter=config["num_sgd_iter"],
                sgd_minibatch_size=config["sgd_minibatch_size"],
            ))
    else:
        train_op = rollouts.for_each(
            MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["sgd_minibatch_size"],
                num_sgd_iter=config["num_sgd_iter"],
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            ))

    # Update KL after each round of training.
    train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))

    # Warn about bad reward scales and return training metrics.
    return StandardMetricsReporting(train_op, workers, config).for_each(
        lambda result: warn_about_bad_reward_scales(config, result))

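# For reference, these are the config keys the PPO plan above consumes. The
# values below are illustrative placeholders (assumptions), not recommended
# settings.
ppo_plan_config_keys = {
    "train_batch_size": 4000,      # ConcatBatches `min_batch_size`
    "sgd_minibatch_size": 128,     # minibatch size per SGD step
    "num_sgd_iter": 30,            # SGD passes over each train batch
    "num_gpus": 0,                 # used by the MultiGPUTrainOneStep path
    "_fake_gpus": False,           # simulate GPUs on CPU (multi-GPU path)
    "simple_optimizer": True,      # True -> TrainOneStep, False -> multi-GPU
    "multiagent": {"count_steps_by": "env_steps"},
}
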
def __init__(
        self,
        *,
        workers: WorkerSet,
        sgd_minibatch_size: int,
        num_sgd_iter: int,
        num_gpus: int,
        _fake_gpus: bool = False,
        # Deprecated args.
        shuffle_sequences=DEPRECATED_VALUE,
        framework=DEPRECATED_VALUE):
    if framework != DEPRECATED_VALUE or \
            shuffle_sequences != DEPRECATED_VALUE:
        deprecation_warning(
            old="MultiGPUTrainOneStep(framework=..., shuffle_sequences=...)",
            error=False,
        )

    self.workers = workers
    self.local_worker = workers.local_worker()
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
    self.shuffle_sequences = shuffle_sequences

    # Collect actual GPU devices to use.
    if not num_gpus:
        _fake_gpus = True
        num_gpus = 1
    type_ = "cpu" if _fake_gpus else "gpu"
    self.devices = [
        "/{}:{}".format(type_, 0 if _fake_gpus else i)
        for i in range(int(math.ceil(num_gpus)))
    ]

    # Make sure the total batch size is divisible by the number of devices.
    # Batch size per tower.
    self.per_device_batch_size = sgd_minibatch_size // len(self.devices)
    # Total batch size.
    self.batch_size = self.per_device_batch_size * len(self.devices)
    assert self.batch_size % len(self.devices) == 0
    assert self.batch_size >= len(self.devices), "Batch size too small!"

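# A small worked example of the batch-splitting arithmetic above, with
# assumed (illustrative) values: sgd_minibatch_size=128 on 3 GPUs.
devices = ["/gpu:0", "/gpu:1", "/gpu:2"]
per_device_batch_size = 128 // len(devices)        # -> 42 samples per tower
batch_size = per_device_batch_size * len(devices)  # -> 126 samples in total
# The 2 leftover samples are dropped so the batch divides evenly across towers.
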
def apex_execution_plan(workers: WorkerSet,
                        config: dict) -> LocalIterator[dict]:
    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    replay_actors = create_colocated(ReplayActor, [
        num_replay_buffer_shards,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config.get("replay_sequence_length", 1),
    ], num_replay_buffer_shards)

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
        actor, prio_dict, count = item
        actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner thread
        # is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our replay buffer actors. Update
    # the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(actors=replay_actors))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        store_op = store_op.zip_with_source_actor() \
            .for_each(UpdateWorkerWeights(
                learner_thread,
                workers,
                max_weight_sync_delay=(
                    config["optimizer"]["max_weight_sync_delay"])
            ))

    # (2) Read experiences from the replay buffer actors and send to the
    # learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(actors=replay_actors, num_async=4) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .zip_with_source_actor() \
        .for_each(Enqueue(learner_thread.inqueue))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive) \
        .for_each(update_prio_and_stats) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"],
            by_steps_trained=True))

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights)
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result: dict) -> dict:
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_trainable_policy(
            lambda p, _: p.get_exploration_info())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            "learner": copy.deepcopy(learner_thread.stats),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    selected_workers = workers.remote_workers()[
        -len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)

def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert len(kwargs) == 0, (
        "MBMPO execution_plan does NOT take any additional parameters")

    # Train TD Models on the driver.
    workers.local_worker().foreach_policy(fit_dynamics)

    # Sync driver's policy with workers.
    workers.sync_weights()

    # Sync TD Models and normalization stats with workers.
    sync_ensemble(workers)
    sync_stats(workers)

    # Dropping metrics from the first iteration.
    _, _ = collect_episodes(
        workers.local_worker(),
        workers.remote_workers(), [],
        timeout_seconds=9999)

    # Metrics Collector.
    metric_collect = CollectMetrics(
        workers,
        min_history=0,
        timeout_seconds=config["metrics_episode_collection_timeout_s"],
    )

    num_inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            print("Collecting Samples, Inner Adaptation {}".format(
                len(split)))
            # Processing Samples (Standardize Advantages).
            samples, split_lst = post_process_samples(samples, config)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            prefix = "DynaTrajInner_" + str(adapt_iter)
            metrics = post_process_metrics(prefix, workers, metrics)

            if len(split) > num_inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []

                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    # Iterator for Inner Adaptation Data gathering (from pre->post
    # adaptation).
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Meta update step with outer combine loop for multiple MAML
    # iterations.
    train_op = rollouts.combine(
        MetaUpdate(
            workers,
            config["num_maml_steps"],
            config["maml_optimizer_steps"],
            metric_collect,
        ))

    return train_op

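# The MB-MPO plan above reads the following config keys. The values shown are
# illustrative assumptions only, not recommended settings.
mbmpo_plan_config_keys = {
    "inner_adaptation_steps": 1,                  # inner-loop adaptation rounds
    "num_maml_steps": 10,                         # meta-update (MAML) iterations
    "maml_optimizer_steps": 8,                    # optimizer steps per meta-update
    "metrics_episode_collection_timeout_s": 180,  # CollectMetrics timeout
}
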
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the MB-MPO algorithm. Defines the distributed
    dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    # Train TD Models on the driver.
    workers.local_worker().foreach_policy(fit_dynamics)

    # Sync driver's policy with workers.
    workers.sync_weights()

    # Sync TD Models and normalization stats with workers.
    sync_ensemble(workers)
    sync_stats(workers)

    # Dropping metrics from the first iteration.
    _, _ = collect_episodes(
        workers.local_worker(),
        workers.remote_workers(), [],
        timeout_seconds=9999)

    # Metrics Collector.
    metric_collect = CollectMetrics(
        workers,
        min_history=0,
        timeout_seconds=config["collect_metrics_timeout"])

    num_inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            print("Collecting Samples, Inner Adaptation {}".format(
                len(split)))
            # Processing Samples (Standardize Advantages).
            samples, split_lst = post_process_samples(samples, config)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            prefix = "DynaTrajInner_" + str(adapt_iter)
            metrics = post_process_metrics(prefix, workers, metrics)

            if len(split) > num_inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []

                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation).
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Meta update step with outer combine loop for multiple MAML iterations.
    train_op = rollouts.combine(
        MetaUpdate(workers, config["num_maml_steps"],
                   config["maml_optimizer_steps"], metric_collect))

    return train_op

def synchronous_parallel_sample(
    *,
    worker_set: WorkerSet,
    max_agent_steps: Optional[int] = None,
    max_env_steps: Optional[int] = None,
    concat: bool = True,
) -> Union[List[SampleBatchType], SampleBatchType]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), the local worker is used
    for sampling.

    Args:
        worker_set: The WorkerSet to use for sampling.
        max_agent_steps: Optional number of agent steps to be included in
            the final batch.
        max_env_steps: Optional number of environment steps to be included
            in the final batch.
        concat: Whether to concat all resulting batches at the end and
            return the concat'd batch.

    Returns:
        The list of collected sample batch types (one for each parallel
        rollout worker in the given `worker_set`), or a single concatenated
        batch if `concat` is True.

    Examples:
        >>> # Define an RLlib trainer.
        >>> trainer = ... # doctest: +SKIP
        >>> # 2 remote workers (num_workers=2):
        >>> batches = synchronous_parallel_sample( # doctest: +SKIP
        ...     worker_set=trainer.workers, concat=False)
        >>> print(len(batches)) # doctest: +SKIP
        2
        >>> print(batches[0]) # doctest: +SKIP
        SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
        >>> # 0 remote workers (num_workers=0): Using the local worker.
        >>> batches = synchronous_parallel_sample( # doctest: +SKIP
        ...     worker_set=trainer.workers, concat=False)
        >>> print(len(batches)) # doctest: +SKIP
        1
    """
    # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
    assert not (max_agent_steps is not None and max_env_steps is not None)

    agent_or_env_steps = 0
    max_agent_or_env_steps = max_agent_steps or max_env_steps or None
    all_sample_batches = []

    # Stop collecting batches as soon as one criterion is met.
    while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
            max_agent_or_env_steps is not None
            and agent_or_env_steps < max_agent_or_env_steps):
        # No remote workers in the set -> Use local worker for collecting
        # samples.
        if not worker_set.remote_workers():
            sample_batches = [worker_set.local_worker().sample()]
        # Loop over remote workers' `sample()` method in parallel.
        else:
            sample_batches = ray.get([
                worker.sample.remote()
                for worker in worker_set.remote_workers()
            ])
        # Update our counters for the stopping criterion of the while loop.
        for b in sample_batches:
            if max_agent_steps:
                agent_or_env_steps += b.agent_steps()
            else:
                agent_or_env_steps += b.env_steps()
        all_sample_batches.extend(sample_batches)

    if concat is True:
        full_batch = SampleBatch.concat_samples(all_sample_batches)
        # Discard collected incomplete episodes in episode mode.
        # if max_episodes is not None and episodes >= max_episodes:
        #     last_complete_ep_idx = len(full_batch) - full_batch[
        #         SampleBatch.DONES
        #     ].reverse().index(1)
        #     full_batch = full_batch.slice(0, last_complete_ep_idx)
        return full_batch
    else:
        return all_sample_batches

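# Usage sketch (the `trainer` object and the step budget are illustrative
# assumptions): keep sampling until at least 2000 env steps have been
# collected, then return a single concatenated SampleBatch.
train_batch = synchronous_parallel_sample(
    worker_set=trainer.workers,
    max_env_steps=2000,
    concat=True,
)
print(train_batch.env_steps())  # >= 2000
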
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                     num_async=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Args:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of 'async', 'bulk_sync', 'raw'. In 'async' mode,
            batches are returned as soon as they are computed by rollout
            workers with no order guarantees. In 'bulk_sync' mode, we collect
            one batch from each worker and concatenate them together into a
            large batch to return. In 'raw' mode, the ParallelIterator object
            is returned directly and the caller is responsible for
            implementing gather and updating the timesteps counter.
        num_async (int): In async mode, the max number of async requests in
            flight per actor.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> from ray.rllib.execution import ParallelRollouts
        >>> workers = ... # doctest: +SKIP
        >>> rollouts = ParallelRollouts(workers, mode="async") # doctest: +SKIP
        >>> batch = next(rollouts) # doctest: +SKIP
        >>> print(batch.count) # doctest: +SKIP
        50  # config.rollout_fragment_length

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync") # doctest: +SKIP
        >>> batch = next(rollouts) # doctest: +SKIP
        >>> print(batch.count) # doctest: +SKIP
        200  # config.rollout_fragment_length * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    def report_timesteps(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the `num_workers=0` case, in which the local worker
        # has to do sampling as well.
        return LocalIterator(
            lambda timeout: workers.local_worker().item_generator,
            SharedMetrics()).for_each(report_timesteps)

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return (rollouts
                .batch_across_shards()
                .for_each(lambda batches: SampleBatch.concat_samples(batches))
                .for_each(report_timesteps))
    elif mode == "async":
        return rollouts.gather_async(
            num_async=num_async).for_each(report_timesteps)
    elif mode == "raw":
        return rollouts
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', 'raw', "
            "got '{}'".format(mode))

def execution_plan(workers: WorkerSet, config: dict,
                   **kwargs) -> LocalIterator[dict]:
    assert len(kwargs) == 0, (
        "Apex execution_plan does NOT take any additional parameters")

    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    buffer_size = (config["replay_buffer_config"]["capacity"] //
                   num_replay_buffer_shards)
    replay_actor_args = [
        num_replay_buffer_shards,
        config["learning_starts"],
        buffer_size,
        config["train_batch_size"],
        config["replay_buffer_config"]["prioritized_replay_alpha"],
        config["replay_buffer_config"]["prioritized_replay_beta"],
        config["replay_buffer_config"]["prioritized_replay_eps"],
        config["multiagent"]["replay_mode"],
        config["replay_buffer_config"].get("replay_sequence_length", 1),
    ]
    # Place all replay buffer shards on the same node as the learner
    # (driver process that runs this execution plan).
    if config["replay_buffer_shards_colocated_with_driver"]:
        replay_actors = create_colocated_actors(
            actor_specs=[
                # (class, args, kwargs={}, count)
                (ReplayActor, replay_actor_args, {},
                 num_replay_buffer_shards)
            ],
            node=platform.node(),  # localhost
        )[0]  # [0]=only one item in `actor_specs`.
    # Place replay buffer shards on any node(s).
    else:
        replay_actors = [
            ReplayActor.remote(*replay_actor_args)
            for _ in range(num_replay_buffer_shards)
        ]

    # Start the learner thread.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(
            item: Tuple[ActorHandle, dict, int, int]) -> None:
        actor, prio_dict, env_count, agent_count = item
        if config.get("prioritized_replay"):
            actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner
        # thread is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
        metrics.counters[STEPS_TRAINED_COUNTER] += env_count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in one of our replay buffer
    # actors. Update the weights of the worker that generated the batch.
    rollouts = ParallelRollouts(workers, mode="async", num_async=2)
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        store_op = store_op.zip_with_source_actor().for_each(
            UpdateWorkerWeights(
                learner_thread,
                workers,
                max_weight_sync_delay=(
                    config["optimizer"]["max_weight_sync_delay"]),
            ))

    # (2) Read experiences from one of the replay buffer actors and send
    # to the learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = (Replay(actors=replay_actors, num_async=4)
                 .for_each(lambda x: post_fn(x, workers, config))
                 .zip_with_source_actor()
                 .for_each(Enqueue(learner_thread.inqueue)))

    # (3) Get priorities back from learner thread and apply them to the
    # replay buffer actors.
    update_op = (Dequeue(learner_thread.outqueue,
                         check=learner_thread.is_alive)
                 .for_each(update_prio_and_stats)
                 .for_each(UpdateTargetNetwork(
                     workers, config["target_network_update_freq"],
                     by_steps_trained=True)))

    if config["training_intensity"]:
        # Execute (1), (2) with a fixed intensity ratio.
        rr_weights = calculate_rr_weights(config) + ["*"]
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="round_robin",
            output_indexes=[2],
            round_robin_weights=rr_weights,
        )
    else:
        # Execute (1), (2), (3) asynchronously as fast as possible. Only
        # output items from (3) since metrics aren't available before
        # then.
        merged_op = Concurrently(
            [store_op, replay_op, update_op],
            mode="async",
            output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result: dict) -> dict:
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_policy_to_train(
            lambda p, _: p.get_exploration_state())
        result["info"].update({
            "exploration_infos": exploration_infos,
            "learner_queue": learner_thread.learner_queue_size.stats(),
            LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
            "replay_shard_0": replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of
    # epsilons.
    selected_workers = workers.remote_workers()[
        -len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)
