def sync_ensemble(workers: WorkerSet) -> None:
    """Syncs dynamics ensemble weights from driver (main) to workers.

    Args:
        workers: Set of workers, including driver (main).
    """

    def get_ensemble_weights(worker):
        policy_map = worker.policy_map
        policies = policy_map.keys()

        def policy_ensemble_weights(policy):
            model = policy.dynamics_model
            return {
                k: v.cpu().detach().numpy()
                for k, v in model.state_dict().items()
            }

        return {
            pid: policy_ensemble_weights(policy)
            for pid, policy in policy_map.items() if pid in policies
        }

    def set_ensemble_weights(policy, pid, weights):
        weights = weights[pid]
        weights = convert_to_torch_tensor(weights, device=policy.device)
        model = policy.dynamics_model
        model.load_state_dict(weights)

    if workers.remote_workers():
        weights = ray.put(get_ensemble_weights(workers.local_worker()))
        set_func = ray.put(set_ensemble_weights)
        for e in workers.remote_workers():
            e.foreach_policy.remote(set_func, weights=weights)
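# A minimal usage sketch for `sync_ensemble` above. Hedged: assumes an
# MBMPO-style Trainer whose policies expose a `dynamics_model` attribute
# (which the function requires); `trainer` and the `fit()` call are
# hypothetical stand-ins, not confirmed API. After (re)fitting the ensemble
# on the driver, broadcast the new weights so remote rollout workers plan
# against the same model.
def refit_and_broadcast(trainer):
    trainer.workers.local_worker().foreach_policy(
        lambda p, _: p.dynamics_model.fit())  # hypothetical fit() helper
    sync_ensemble(trainer.workers)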
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Args:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """

    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              MetricsContext()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))
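# Hedged sketch of how ParallelRollouts typically feeds a training op
# (assumes TrainOneStep from ray.rllib.execution.train_ops; a minimal plan
# for illustration, not the canonical plan of any particular algorithm):
from ray.rllib.execution.train_ops import TrainOneStep

def simple_execution_plan(workers: WorkerSet):
    # "bulk_sync" yields one concatenated batch per round, which suits
    # synchronous algorithms; "async" would suit asynchronous ones.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    return rollouts.for_each(TrainOneStep(workers))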
def synchronous_parallel_sample(workers: WorkerSet) -> List[SampleBatch]:
    # No remote workers in the set -> Use local worker for collecting
    # samples.
    if not workers.remote_workers():
        return [workers.local_worker().sample()]

    # Loop over remote workers' `sample()` method in parallel.
    sample_batches = ray.get(
        [r.sample.remote() for r in workers.remote_workers()])

    return sample_batches
def set_policy_with_env_fn(worker_set: WorkerSet, fn_type: str = "reward"):
    """Sets the desired environment function for all policies in the worker set.

    Args:
        worker_set: A worker set instance, usually from a trainer.
        fn_type: The type of environment function: 'reward',
            'termination', or 'dynamics'.
    """
    worker_set.foreach_worker(lambda w: w.foreach_policy(
        lambda p, _: _set_from_env_if_possible(p, w.env, fn_type)))
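# Hedged usage sketch (assumes a model-based trainer whose env exposes the
# functions that `_set_from_env_if_possible` looks for; `trainer` is a
# hypothetical stand-in): wire the env's reward and termination functions
# into every policy before planning.
def wire_env_fns(trainer):
    set_policy_with_env_fn(trainer.workers, fn_type="reward")
    set_policy_with_env_fn(trainer.workers, fn_type="termination")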
def sync_stats(workers: WorkerSet) -> None:
    """Syncs dynamics-model normalization stats from driver (main) to workers."""

    def get_normalizations(worker):
        policy = worker.policy_map[DEFAULT_POLICY_ID]
        return policy.dynamics_model.normalizations

    def set_normalizations(policy, pid, normalizations):
        policy.dynamics_model.set_norms(normalizations)

    if workers.remote_workers():
        normalization_dict = ray.put(
            get_normalizations(workers.local_worker()))
        set_func = ray.put(set_normalizations)
        for e in workers.remote_workers():
            e.foreach_policy.remote(
                set_func, normalizations=normalization_dict)
def synchronous_parallel_sample(
    worker_set: WorkerSet,
    remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
) -> List[SampleBatch]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), use the local worker
    for sampling.

    Alternatively to calling `worker.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the worker(s) instead.

    Args:
        worker_set: The WorkerSet to use for sampling.
        remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
            of `worker.sample.remote()` to generate the requests.

    Returns:
        The list of collected sample batch types (one for each parallel
        rollout worker in the given `worker_set`).

    Examples:
        >>> # Define an RLlib trainer.
        >>> trainer = ... # doctest: +SKIP
        >>> # 2 remote workers (num_workers=2):
        >>> batches = synchronous_parallel_sample(trainer.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        2
        >>> print(batches[0]) # doctest: +SKIP
        SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
        >>> # 0 remote workers (num_workers=0): Using the local worker.
        >>> batches = synchronous_parallel_sample(trainer.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        1
    """
    # No remote workers in the set -> Use local worker for collecting
    # samples.
    if not worker_set.remote_workers():
        return [worker_set.local_worker().sample()]

    # Loop over remote workers' `sample()` method in parallel (or apply
    # `remote_fn`, if provided).
    sample_batches = ray.get([
        (r.apply.remote(remote_fn) if remote_fn else r.sample.remote())
        for r in worker_set.remote_workers()
    ])

    # Return all collected batches.
    return sample_batches
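# Hedged usage sketch for the `remote_fn` path above, under the assumption
# that the callable returns a SampleBatch, which `apply.remote` passes back
# to the driver (`sample_twice` is a hypothetical helper for illustration):
def sample_twice(worker):
    # Draw two rollouts per worker and merge them into one batch.
    return worker.sample().concat(worker.sample())

def collect_double_batches(worker_set: WorkerSet):
    return synchronous_parallel_sample(worker_set, remote_fn=sample_twice)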
def __init__(self,
             *,
             workers: WorkerSet,
             sgd_minibatch_size: int,
             num_sgd_iter: int,
             num_gpus: int,
             shuffle_sequences: bool,
             _fake_gpus: bool = False,
             framework: str = "tf"):
    self.workers = workers
    self.local_worker = workers.local_worker()
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
    self.shuffle_sequences = shuffle_sequences
    self.framework = framework

    # Collect actual GPU devices to use.
    if not num_gpus:
        _fake_gpus = True
        num_gpus = 1
    type_ = "cpu" if _fake_gpus else "gpu"
    self.devices = [
        "/{}:{}".format(type_, 0 if _fake_gpus else i)
        for i in range(int(math.ceil(num_gpus)))
    ]

    # Make sure total batch size is divisible by the number of devices.
    # Batch size per tower.
    self.per_device_batch_size = sgd_minibatch_size // len(self.devices)
    # Total batch size.
    self.batch_size = self.per_device_batch_size * len(self.devices)
    assert self.batch_size % len(self.devices) == 0
    assert self.batch_size >= len(self.devices), "Batch size too small!"
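# Worked example of the batch-size arithmetic above (plain Python, no RLlib;
# the numbers are made up for illustration): with sgd_minibatch_size=100 and
# 3 devices, each tower gets 100 // 3 = 33 samples and the effective total
# batch shrinks to 33 * 3 = 99, which keeps both assertions satisfied.
devices = ["/gpu:0", "/gpu:1", "/gpu:2"]
per_device_batch_size = 100 // len(devices)        # -> 33
batch_size = per_device_batch_size * len(devices)  # -> 99
assert batch_size % len(devices) == 0
assert batch_size >= len(devices)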
def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for w in range(len(workers.remote_workers())):
        worker_assignments[i].append(w)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that
    # aggregate experiences in parallel. We force colocation on the same
    # node to maximize data bandwidth between them and the driver.
    train_batches = from_actors([
        create_colocated(Aggregator, [config, g], 1)[0]
        for g in rollout_groups
    ])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
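# Hedged illustration of the round-robin worker-to-aggregator assignment
# performed above (plain Python; the worker counts are made up for the
# example): 5 rollout workers split across 2 aggregation workers.
num_aggregation_workers = 2
assignments = [[] for _ in range(num_aggregation_workers)]
for w in range(5):
    assignments[w % num_aggregation_workers].append(w)
assert assignments == [[0, 2, 4], [1, 3]]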
def test_reject_bad_configs(self):
    local, remotes = self._make_envs()
    workers = WorkerSet._from_existing(local, remotes)
    self.assertRaises(
        ValueError, lambda: AsyncSamplesOptimizer(
            local, remotes, num_data_loader_buffers=2,
            minibatch_buffer_size=4))
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=1,
        train_batch_size=100,
        rollout_fragment_length=50,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=1,
        train_batch_size=100,
        rollout_fragment_length=25,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=1,
        train_batch_size=100,
        rollout_fragment_length=74,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
def LocalComputeUpdates(workers: WorkerSet, significance_threshold):
    rollouts = from_actors(workers.remote_workers())

    def train_on_batch(samples):
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)
        worker = get_global_worker()
        if not hasattr(worker, 'num_iterations_trained'):
            worker.num_iterations_trained = 0
        info = worker.learn_on_batch(samples)
        worker.foreach_trainable_policy(
            lambda p, pid: p.asp_accumulate_grads())
        worker.num_iterations_trained += 1
        info['num_iterations_trained'] = worker.num_iterations_trained
        updates = {
            pid: worker.get_policy(pid).asp_get_updates(
                significance_threshold)
            for pid in worker.policies_to_train
        }
        return updates, info, samples.count, 1

    res = rollouts.for_each(train_on_batch)
    return res
def LocalTrainOneStep(workers: WorkerSet,
                      num_sgd_iter: int = 1,
                      sgd_minibatch_size: int = 0):
    rollouts = from_actors(workers.remote_workers())

    def train_on_batch(samples):
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)
        worker = get_global_worker()
        if not hasattr(worker, 'num_iterations_trained'):
            worker.num_iterations_trained = 0
        if num_sgd_iter > 1:
            info = do_minibatch_sgd(
                samples, {
                    pid: worker.get_policy(pid)
                    for pid in worker.policies_to_train
                }, worker, num_sgd_iter, sgd_minibatch_size, [])
        else:
            info = worker.learn_on_batch(samples)
        worker.num_iterations_trained += 1
        info['num_iterations_trained'] = worker.num_iterations_trained
        return info, samples.count, num_sgd_iter

    info = rollouts.for_each(train_on_batch)
    return info
def test_train_external_multi_agent_cartpole_many_policies(self):
    n = 20
    single_env = gym.make("CartPole-v0")
    act_space = single_env.action_space
    obs_space = single_env.observation_space
    policies = {}
    for i in range(n):
        policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {})
    policy_ids = list(policies.keys())
    ev = RolloutWorker(
        env_creator=lambda _: MultiAgentCartPole({"num_agents": n}),
        policy=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        rollout_fragment_length=100)
    optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev))
    for i in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(i,
                                            result["policy_reward_mean"]))
        print("Total reward", result["episode_reward_mean"])
        if result["episode_reward_mean"] >= 25 * n:
            return
    raise Exception("failed to improve reward")
def _init(self, config: TrainerConfigDict, env_creator: EnvCreator):
    # No `get_policy_class` function.
    if get_policy_class is None:
        # Default_policy must be provided (unless in multi-agent mode,
        # where each policy can have its own default policy class).
        if not config["multiagent"]["policies"]:
            assert default_policy is not None
    # Query the function for a class to use.
    else:
        self._policy_class = get_policy_class(config)
        # If None returned, use default policy (must be provided).
        if self._policy_class is None:
            assert default_policy is not None
            self._policy_class = default_policy

    if before_init:
        before_init(self)

    # Creating all workers (excluding evaluation workers).
    self.workers = WorkerSet(
        env_creator=env_creator,
        validate_env=validate_env,
        policy_class=self._policy_class,
        trainer_config=config,
        num_workers=self.config["num_workers"],
    )

    self.train_exec_impl = self.execution_plan(
        self.workers, config, **self._kwargs_for_execution_plan())

    if after_init:
        after_init(self)
def __init__(self,
             num_sets,
             env_creator,
             policy,
             trainer_config=None,
             num_workers_per_set=0,
             logdir=None,
             _setup=True):
    self._worker_sets = {}
    for i in range(num_sets):
        self._worker_sets[i] = WorkerSet(env_creator, policy,
                                         trainer_config,
                                         num_workers_per_set, logdir,
                                         _setup)
def _testWithOptimizer(self, optimizer_cls):
    n = 3
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls == SyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNTFPolicy, obs_space, act_space, dqn_config),
            "p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            "p1": (PGTFPolicy, obs_space, act_space, {}),
            "p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
        }
    worker = RolloutWorker(
        env_creator=lambda _: MultiCartpole(n),
        policy=policies,
        policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
        batch_steps=50)
    if optimizer_cls == AsyncGradientsOptimizer:

        def policy_mapper(agent_id):
            return ["p1", "p2"][agent_id % 2]

        remote_workers = [
            RolloutWorker.as_remote().remote(
                env_creator=lambda _: MultiCartpole(n),
                policy=policies,
                policy_mapping_fn=policy_mapper,
                batch_steps=50)
        ]
    else:
        remote_workers = []
    workers = WorkerSet._from_existing(worker, remote_workers)
    optimizer = optimizer_cls(workers)
    for i in range(200):
        worker.foreach_policy(
            lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
            if isinstance(p, DQNTFPolicy) else None)
        optimizer.step()
        result = collect_metrics(worker, remote_workers)
        if i % 20 == 0:

            def do_update(p):
                if isinstance(p, DQNTFPolicy):
                    p.update_target()

            worker.foreach_policy(lambda p, _: do_update(p))
            print("Iter {}, rew {}".format(i,
                                           result["policy_reward_mean"]))
            print("Total reward", result["episode_reward_mean"])
        if result["episode_reward_mean"] >= 25 * n:
            return
    print(result)
    raise Exception("failed to improve reward")
def test_train_multi_cartpole_many_policies(self):
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {}
    for i in range(n):
        policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {})
    policy_ids = list(policies.keys())
    worker = RolloutWorker(
        env_creator=lambda _: MultiCartpole(n),
        policy=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    workers = WorkerSet._from_existing(worker, [])
    optimizer = SyncSamplesOptimizer(workers)
    for i in range(100):
        optimizer.step()
        result = collect_metrics(worker)
        print("Iteration {}, rew {}".format(i,
                                            result["policy_reward_mean"]))
        print("Total reward", result["episode_reward_mean"])
        if result["episode_reward_mean"] >= 25 * n:
            return
    raise Exception("failed to improve reward")
def _make_workers(self, env_creator, policy, config, num_workers):
    return WorkerSet(
        env_creator,
        policy,
        config,
        num_workers=num_workers,
        logdir=self.logdir)
def _make_workers(self, env_creator, policy, config, num_workers): """Default factory method for a WorkerSet running under this Trainer. Override this method by passing a custom `make_workers` into `build_trainer`. Args: env_creator (callable): A function that return and Env given an env config. policy (class): The Policy class to use for creating the policies of the workers. config (dict): The Trainer's config. num_workers (int): Number of remote rollout workers to create. 0 for local only. remote_config_updates (Optional[List[dict]]): A list of config dicts to update `config` with for each Worker (len must be same as `num_workers`). Returns: WorkerSet: The created WorkerSet. """ return WorkerSet(env_creator, policy, config, num_workers=num_workers, logdir=self.logdir)
def __init__(self,
             *,
             workers: WorkerSet,
             sgd_minibatch_size: int,
             num_sgd_iter: int,
             num_gpus: int,
             shuffle_sequences: bool,
             policies: List[PolicyID] = frozenset([]),
             _fake_gpus: bool = False,
             framework: str = "tf"):
    self.workers = workers
    self.policies = policies or workers.local_worker().policies_to_train
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
    self.shuffle_sequences = shuffle_sequences
    self.framework = framework

    # Collect actual GPU devices to use.
    if not num_gpus:
        _fake_gpus = True
        num_gpus = 1
    type_ = "cpu" if _fake_gpus else "gpu"
    self.devices = [
        "/{}:{}".format(type_, 0 if _fake_gpus else i)
        for i in range(int(math.ceil(num_gpus)))
    ]

    # Total batch size (all towers). Make sure it is divisible by the
    # number of towers.
    self.batch_size = int(sgd_minibatch_size / len(self.devices)) * len(
        self.devices)
    assert self.batch_size % len(self.devices) == 0
    assert self.batch_size >= len(self.devices), "batch size too small"
    # Batch size per tower.
    self.per_device_batch_size = int(self.batch_size / len(self.devices))

    # The per-GPU graph copies created below must share vars with the
    # policy. `reuse` is set to AUTO_REUSE because Adam nodes are created
    # after all of the device copies are created.
    self.optimizers = {}
    with self.workers.local_worker().tf_sess.graph.as_default():
        with self.workers.local_worker().tf_sess.as_default():
            for policy_id in self.policies:
                policy = self.workers.local_worker().get_policy(policy_id)
                with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
                    if policy._state_inputs:
                        rnn_inputs = policy._state_inputs + [
                            policy._seq_lens
                        ]
                    else:
                        rnn_inputs = []
                    self.optimizers[policy_id] = (
                        LocalSyncParallelOptimizer(
                            policy._optimizer, self.devices,
                            list(policy._loss_input_dict_no_rnn.values()),
                            rnn_inputs, self.per_device_batch_size,
                            policy.copy))

            self.sess = self.workers.local_worker().tf_sess
            self.sess.run(tf1.global_variables_initializer())
def testMultiTierAggregation(self):
    local, remotes = self._make_evs()
    workers = WorkerSet._from_existing(local, remotes)
    aggregators = TreeAggregator.precreate_aggregators(1)
    optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=1)
    optimizer.aggregator.init(aggregators)
    self._wait_for(optimizer, 1000, 1000)
def testRejectBadConfigs(self):
    local, remotes = self._make_evs()
    workers = WorkerSet._from_existing(local, remotes)
    self.assertRaises(
        ValueError, lambda: AsyncSamplesOptimizer(
            local, remotes, num_data_loader_buffers=2,
            minibatch_buffer_size=4))
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=2,
        train_batch_size=100,
        sample_batch_size=50,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=2,
        train_batch_size=100,
        sample_batch_size=25,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
    optimizer = AsyncSamplesOptimizer(
        workers,
        num_gpus=2,
        train_batch_size=100,
        sample_batch_size=74,
        _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
def testMultiTierAggregationBadConf(self):
    local, remotes = self._make_evs()
    workers = WorkerSet._from_existing(local, remotes)
    aggregators = TreeAggregator.precreate_aggregators(4)
    optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=4)
    self.assertRaises(ValueError,
                      lambda: optimizer.aggregator.init(aggregators))
def AsyncGradients(
        workers: WorkerSet) -> LocalIterator[Tuple[ModelGradients, int]]:
    """Operator to compute gradients in parallel from rollout workers.

    Args:
        workers (WorkerSet): set of rollout workers to use.

    Returns:
        A local iterator over policy gradients computed on rollout workers.

    Examples:
        >>> from ray.rllib.execution.rollout_ops import AsyncGradients
        >>> workers = ... # doctest: +SKIP
        >>> grads_op = AsyncGradients(workers) # doctest: +SKIP
        >>> print(next(grads_op)) # doctest: +SKIP
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    # This function will be applied remotely on the workers.
    def samples_to_grads(samples):
        return get_global_worker().compute_gradients(samples), samples.count

    # Record learner metrics and pass through (grads, count).
    class record_metrics:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, item):
            (grads, info), count = item
            metrics = _get_shared_metrics()
            metrics.counters[STEPS_SAMPLED_COUNTER] += count
            metrics.info[LEARNER_INFO] = ({
                DEFAULT_POLICY_ID: info
            } if LEARNER_STATS_KEY in info else info)
            metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                                 self.fetch_start_time)
            return grads, count

    rollouts = from_actors(workers.remote_workers())
    grads = rollouts.for_each(samples_to_grads)
    return grads.gather_async().for_each(record_metrics())
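# Hedged consumer sketch for AsyncGradients (assumes ApplyGradients from
# ray.rllib.execution.train_ops; mirrors the A3C-style pattern of applying
# each async gradient on the driver, without claiming to be any trainer's
# exact plan):
from ray.rllib.execution.train_ops import ApplyGradients

def async_gradients_plan(workers: WorkerSet):
    grads_op = AsyncGradients(workers)
    # update_all=False: only the worker that produced the gradient gets
    # the fresh weights right away.
    return grads_op.for_each(ApplyGradients(workers, update_all=False))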
def test_basic(self):
    local = _MockWorker()
    remotes = ray.remote(_MockWorker)
    remote_workers = [remotes.remote() for i in range(5)]
    workers = WorkerSet._from_existing(local, remote_workers)
    test_optimizer = AsyncGradientsOptimizer(workers, grads_per_step=10)
    test_optimizer.step()
    self.assertTrue(all(local.get_weights() == 0))
def testMultiGPUParallelLoad(self):
    local, remotes = self._make_evs()
    workers = WorkerSet._from_existing(local, remotes)
    optimizer = AsyncSamplesOptimizer(
        workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True)
    self._wait_for(optimizer, 1000, 1000)
def testBasic(self):
    ray.init(num_cpus=4, object_store_memory=1000 * 1024 * 1024)
    local = _MockWorker()
    remotes = ray.remote(_MockWorker)
    remote_workers = [remotes.remote() for i in range(5)]
    workers = WorkerSet._from_existing(local, remote_workers)
    test_optimizer = AsyncGradientsOptimizer(workers, grads_per_step=10)
    test_optimizer.step()
    self.assertTrue(all(local.get_weights() == 0))
def __init__(self,
             workers: WorkerSet,
             policies: List[PolicyID] = frozenset([]),
             num_sgd_iter: int = 1,
             sgd_minibatch_size: int = 0):
    self.workers = workers
    self.policies = policies or workers.local_worker().policies_to_train
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
def testLearnerQueueTimeout(self):
    local, remotes = self._make_envs()
    workers = WorkerSet._from_existing(local, remotes)
    optimizer = AsyncSamplesOptimizer(
        workers,
        sample_batch_size=1000,
        train_batch_size=1000,
        learner_queue_timeout=1)
    self.assertRaises(AssertionError,
                      lambda: self._wait_for(optimizer, 1000, 1000))
def __init__(self,
             workers: WorkerSet,
             sgd_minibatch_size: int,
             num_sgd_iter: int,
             num_gpus: int,
             rollout_fragment_length: int,
             num_envs_per_worker: int,
             train_batch_size: int,
             shuffle_sequences: bool,
             policies: List[PolicyID] = frozenset([]),
             _fake_gpus: bool = False):
    self.workers = workers
    self.policies = policies or workers.local_worker().policies_to_train
    self.num_sgd_iter = num_sgd_iter
    self.sgd_minibatch_size = sgd_minibatch_size
    self.shuffle_sequences = shuffle_sequences

    # Collect actual devices to use.
    if not num_gpus:
        _fake_gpus = True
        num_gpus = 1
    type_ = "cpu" if _fake_gpus else "gpu"
    self.devices = [
        "/{}:{}".format(type_, i) for i in range(int(math.ceil(num_gpus)))
    ]

    self.batch_size = int(sgd_minibatch_size / len(self.devices)) * len(
        self.devices)
    assert self.batch_size % len(self.devices) == 0
    assert self.batch_size >= len(self.devices), "batch size too small"
    self.per_device_batch_size = int(self.batch_size / len(self.devices))

    # The per-GPU graph copies created below must share vars with the
    # policy. `reuse` is set to AUTO_REUSE because Adam nodes are created
    # after all of the device copies are created.
    self.optimizers = {}
    with self.workers.local_worker().tf_sess.graph.as_default():
        with self.workers.local_worker().tf_sess.as_default():
            for policy_id in self.policies:
                policy = self.workers.local_worker().get_policy(policy_id)
                with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
                    if policy._state_inputs:
                        rnn_inputs = policy._state_inputs + [
                            policy._seq_lens
                        ]
                    else:
                        rnn_inputs = []
                    self.optimizers[policy_id] = (
                        LocalSyncParallelOptimizer(
                            policy._optimizer, self.devices,
                            [v for _, v in policy._loss_inputs],
                            rnn_inputs, self.per_device_batch_size,
                            policy.copy))

            self.sess = self.workers.local_worker().tf_sess
            self.sess.run(tf.global_variables_initializer())
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    rollouts = ParallelRollouts(workers, mode="async")

    # Collect batches for the trainable policies.
    rollouts = rollouts.for_each(
        SelectExperiences(local_worker=workers.local_worker()))

    # Return training metrics.
    return StandardMetricsReporting(rollouts, workers, config)
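# Hedged usage sketch: an execution plan is just a LocalIterator over result
# dicts, so stepping it drives one round of sampling plus metrics reporting
# (`workers` and `config` are assumed to come from a Trainer):
def run_one_iteration(workers: WorkerSet, config: TrainerConfigDict):
    train_op = execution_plan(workers, config)
    return next(train_op)  # one training-iteration result dict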