def LocalTrainOneStep(workers: WorkerSet,
                      num_sgd_iter: int = 1,
                      sgd_minibatch_size: int = 0):
    rollouts = from_actors(workers.remote_workers())

    def train_on_batch(samples):
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        worker = get_global_worker()
        if not hasattr(worker, 'num_iterations_trained'):
            worker.num_iterations_trained = 0

        if num_sgd_iter > 1:
            info = do_minibatch_sgd(
                samples, {
                    pid: worker.get_policy(pid)
                    for pid in worker.policies_to_train
                }, worker, num_sgd_iter, sgd_minibatch_size, [])
        else:
            info = worker.learn_on_batch(samples)

        worker.num_iterations_trained += 1
        info['num_iterations_trained'] = worker.num_iterations_trained

        return info, samples.count, num_sgd_iter

    info = rollouts.for_each(train_on_batch)
    return info

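# Usage sketch (assumption, not part of the source above): LocalTrainOneStep
# returns a ParallelIterator whose elements are (info, count, num_sgd_iter)
# tuples produced on each rollout worker. A driver-side loop could gather them
# asynchronously as below; `workers` is assumed to be an initialized WorkerSet.
def example_local_training_loop(workers):
    train_op = LocalTrainOneStep(workers, num_sgd_iter=2,
                                 sgd_minibatch_size=128)
    for info, count, sgd_iters in train_op.gather_async():
        # Each item reflects one worker's local update on its own samples.
        print("trained on", count, "env steps with", sgd_iters,
              "SGD iterations")
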
def _create_ml_dataset(name: str,
                       record_pieces: List[RecordPiece],
                       record_sizes: List[int],
                       num_shards: int,
                       shuffle: bool,
                       shuffle_seed: int,
                       RecordBatchCls,
                       node_hints: List[str] = None) -> MLDataset:
    if node_hints is not None:
        assert num_shards % len(node_hints) == 0,\
            f"num_shards: {num_shards} should be a multiple of length of node_hints: {node_hints}"

    if shuffle_seed:
        np.random.seed(shuffle_seed)
    else:
        np.random.seed(0)

    # split the pieces into num_shards partitions
    divided_blocks = divide_blocks(blocks=record_sizes,
                                   world_size=num_shards,
                                   shuffle=shuffle,
                                   shuffle_seed=shuffle_seed)

    record_batches = []
    for rank, blocks in divided_blocks.items():
        pieces = []
        for index, num_samples in blocks:
            record_size = record_sizes[index]
            piece = record_pieces[index]
            if num_samples != record_size:
                assert num_samples < record_size
                new_row_ids = np.random.choice(
                    record_size, size=num_samples).tolist()
                piece = piece.with_row_ids(new_row_ids)
            pieces.append(piece)

        if shuffle:
            np.random.shuffle(pieces)
        record_batches.append(RecordBatchCls(shard_id=rank,
                                             prefix=name,
                                             record_pieces=pieces,
                                             shuffle=shuffle,
                                             shuffle_seed=shuffle_seed))

    worker_cls = ray.remote(ParallelIteratorWorkerWithLen)
    if node_hints is not None:
        actors = []
        multiplier = num_shards // len(node_hints)
        resource_keys = [f"node:{node_hints[i // multiplier]}"
                         for i in range(num_shards)]
        for g, resource_key in zip(record_batches, resource_keys):
            actor = worker_cls.options(
                resources={resource_key: 0.01}).remote(g, False, len(g))
            actors.append(actor)
    else:
        actors = [worker_cls.remote(g, False, len(g)) for g in record_batches]

    it = parallel_it.from_actors(actors, name)
    ds = ml_dataset.from_parallel_iter(
        it, need_convert=False, batch_size=0, repeated=False)
    return ds

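# Illustration (assumption, not part of the function above): when node_hints
# is given, each shard actor is pinned to a node via a custom "node:<ip>"
# resource. For node_hints=["10.0.0.1", "10.0.0.2"] and num_shards=4, the
# multiplier is 2 and the keys come out as
# ["node:10.0.0.1", "node:10.0.0.1", "node:10.0.0.2", "node:10.0.0.2"].
def example_resource_keys(node_hints, num_shards):
    multiplier = num_shards // len(node_hints)
    return [f"node:{node_hints[i // multiplier]}" for i in range(num_shards)]
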
def LocalComputeUpdates(workers: WorkerSet, significance_threshold):
    rollouts = from_actors(workers.remote_workers())

    def train_on_batch(samples):
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        worker = get_global_worker()
        if not hasattr(worker, 'num_iterations_trained'):
            worker.num_iterations_trained = 0

        info = worker.learn_on_batch(samples)
        worker.foreach_trainable_policy(
            lambda p, pid: p.asp_accumulate_grads())

        worker.num_iterations_trained += 1
        info['num_iterations_trained'] = worker.num_iterations_trained

        updates = {
            pid: worker.get_policy(pid).asp_get_updates(significance_threshold)
            for pid in worker.policies_to_train
        }

        return updates, info, samples.count, 1

    res = rollouts.for_each(train_on_batch)
    return res

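# Consumption sketch (assumption, not part of the source above): each element
# yielded by LocalComputeUpdates is an (updates, info, count, iters) tuple,
# where `updates` maps a policy ID to whatever asp_get_updates() returned for
# that policy (the parameter deltas exceeding significance_threshold).
def example_collect_updates(workers, significance_threshold=0.01):
    update_op = LocalComputeUpdates(workers, significance_threshold)
    for updates, info, count, _ in update_op.gather_async():
        # In a full ASP-style setup these deltas would be broadcast to the
        # other workers; here we only report what was produced.
        for pid in updates:
            print("policy", pid, "produced an update after", count, "samples")
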
def gather_experiences_tree_aggregation(workers, config):
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for w in range(len(workers.remote_workers())):
        worker_assignments[i].append(w)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node to maximize
    # data bandwidth between them and the driver.
    train_batches = from_actors([
        create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups
    ])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)

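# Illustration (assumption, not part of the function above): the assignment
# loop distributes rollout workers over aggregators round-robin. A standalone
# sketch of the same logic, e.g. for 5 rollout workers and 2 aggregation
# workers, yields [[0, 2, 4], [1, 3]].
def example_round_robin_assignment(num_workers=5, num_aggregation_workers=2):
    assignments = [[] for _ in range(num_aggregation_workers)]
    for w in range(num_workers):
        assignments[w % num_aggregation_workers].append(w)
    return assignments
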
def execution_plan(workers, config):
    # Sync workers with meta policy
    workers.sync_weights()

    # Samples and sets worker tasks
    use_meta_env = config["use_meta_env"]
    set_worker_tasks(workers, use_meta_env)

    # Metric Collector
    metric_collect = CollectMetrics(
        workers,
        min_history=config["metrics_smoothing_episodes"],
        timeout_seconds=config["collect_metrics_timeout"])

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
    inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            # Processing Samples (Standardize Advantages)
            split_lst = []
            for sample in samples:
                sample["advantages"] = standardized(sample["advantages"])
                split_lst.append(sample.count)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            metrics = post_process_metrics(adapt_iter, workers, metrics)
            if len(split) > inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []

                # Reporting Adaptation Rew Diff
                ep_rew_pre = metrics["episode_reward_mean"]
                ep_rew_post = metrics["episode_reward_mean_adapt_" +
                                      str(inner_steps)]
                metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Metaupdate Step
    train_op = rollouts.for_each(
        MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect,
                   use_meta_env))

    return train_op

def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """

    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              MetricsContext()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))

def test_from_actors(ray_start_regular_shared):
    @ray.remote
    class CustomWorker(ParallelIteratorWorker):
        def __init__(self, data):
            ParallelIteratorWorker.__init__(self, data, False)

    a = CustomWorker.remote([1, 2])
    b = CustomWorker.remote([3, 4])
    it = from_actors([a, b])
    assert repr(it) == "ParallelIterator[from_actors[shards=2]]"
    assert list(it.gather_sync()) == [1, 3, 2, 4]

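# Follow-up sketch (assumption, not part of the test above): the parallel
# iterator can also be transformed before gathering. `for_each` applies a
# function to every element on the shard actors, and `gather_sync` still
# interleaves the shards round-robin. This assumes freshly created shard
# actors like `a` and `b` above, since gathering consumes their data.
def example_for_each_then_gather(a, b):
    it = from_actors([a, b])
    doubled = it.for_each(lambda x: x * 2)
    assert list(doubled.gather_sync()) == [2, 6, 4, 8]
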
def Replay(
    *,
    local_buffer: Optional[MultiAgentReplayBuffer] = None,
    num_items_to_replay: int = 1,
    actors: Optional[List[ActorHandle]] = None,
    num_async: int = 4,
) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Only one of this and replay_actors
            can be specified.
        num_items_to_replay: Number of items to sample from the buffer.
        actors: List of replay actors. Only one of this and local_buffer
            can be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors, # doctest: +SKIP
        ...     num_items_to_replay=batch_size)
        >>> next(replay_op) # doctest: +SKIP
        SampleBatch(...)
    """
    if local_buffer is not None and actors is not None:
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

    if actors is not None:
        for actor in actors:
            actor.make_iterator.remote(
                num_items_to_replay=num_items_to_replay)
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.sample(num_items_to_replay)
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())

def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    localhost = platform.node()
    assert localhost != "", ("ERROR: Cannot determine local node name! "
                             "`platform.node()` returned empty string.")
    all_co_located = create_colocated_actors(
        actor_specs=[
            # (class, args, kwargs={}, count=1)
            (Aggregator, [config, g], {}, 1) for g in rollout_groups
        ],
        node=localhost,
    )

    # Use the first ([0]) of each created group (each group only has one
    # actor: count=1).
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)

def AsyncGradients(
        workers: WorkerSet) -> LocalIterator[Tuple[ModelGradients, int]]:
    """Operator to compute gradients in parallel from rollout workers.

    Args:
        workers (WorkerSet): set of rollout workers to use.

    Returns:
        A local iterator over policy gradients computed on rollout workers.

    Examples:
        >>> from ray.rllib.execution.rollout_ops import AsyncGradients
        >>> workers = ... # doctest: +SKIP
        >>> grads_op = AsyncGradients(workers) # doctest: +SKIP
        >>> print(next(grads_op)) # doctest: +SKIP
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    # This function will be applied remotely on the workers.
    def samples_to_grads(samples):
        return get_global_worker().compute_gradients(samples), samples.count

    # Record learner metrics and pass through (grads, count).
    class record_metrics:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, item):
            (grads, info), count = item
            metrics = _get_shared_metrics()
            metrics.counters[STEPS_SAMPLED_COUNTER] += count
            metrics.info[LEARNER_INFO] = ({
                DEFAULT_POLICY_ID: info
            } if LEARNER_STATS_KEY in info else info)
            metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                                 self.fetch_start_time)
            return grads, count

    rollouts = from_actors(workers.remote_workers())
    grads = rollouts.for_each(samples_to_grads)
    return grads.gather_async().for_each(record_metrics())

def Replay(*,
           local_buffer: LocalReplayBuffer = None,
           actors: List["ActorHandle"] = None,
           num_async=4):
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Arguments:
        local_buffer (LocalReplayBuffer): Local buffer to use. Only one of
            this and replay_actors can be specified.
        actors (list): List of replay actors. Only one of this and
            local_buffer can be specified.
        num_async (int): In async mode, the max number of async requests in
            flight per actor.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors)
        >>> next(replay_op)
        SampleBatch(...)

    Code From:
        https://github.com/ray-project/ray/blob/ray-0.8.7/rllib/execution/replay_ops.py
    """
    if bool(local_buffer) == bool(actors):
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

    if actors:
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.replay()
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())

def Replay(*,
           local_buffer: MultiAgentReplayBuffer = None,
           actors: List[ActorHandle] = None,
           num_async: int = 4) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Only one of this and replay_actors
            can be specified.
        actors: List of replay actors. Only one of this and local_buffer
            can be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors)
        >>> next(replay_op)
        SampleBatch(...)
    """
    if bool(local_buffer) == bool(actors):
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

    if actors:
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.replay()
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())

def ParallelReplay(replay_actors: List["ActorHandle"], async_queue_depth=4):
    """Replay experiences in parallel from the given actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Arguments:
        replay_actors (list): List of replay actors.
        async_queue_depth (int): In async mode, the max number of async
            requests in flight per actor.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = ParallelReplay(actors)
        >>> next(replay_op)
        SampleBatch(...)
    """
    replay = from_actors(replay_actors)
    return replay.gather_async(
        async_queue_depth=async_queue_depth).filter(lambda x: x is not None)

def AsyncGradients(
        workers: WorkerSet) -> LocalIterator[Tuple[GradientType, int]]:
    """Operator to compute gradients in parallel from rollout workers.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.

    Returns:
        A local iterator over policy gradients computed on rollout workers.

    Examples:
        >>> grads_op = AsyncGradients(workers)
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    # This function will be applied remotely on the workers.
    def samples_to_grads(samples):
        return get_global_worker().compute_gradients(samples), samples.count

    # Record learner metrics and pass through (grads, count).
    class record_metrics:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, item):
            (grads, info), count = item
            metrics = LocalIterator.get_metrics()
            metrics.counters[STEPS_SAMPLED_COUNTER] += count
            metrics.info[LEARNER_INFO] = get_learner_stats(info)
            metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                                 self.fetch_start_time)
            return grads, count

    rollouts = from_actors(workers.remote_workers())
    grads = rollouts.for_each(samples_to_grads)
    return grads.gather_async().for_each(record_metrics())

def execution_plan(workers, config):
    # Sync workers with meta policy
    workers.sync_weights()

    # Samples and sets worker tasks
    set_worker_tasks(workers)

    # Metric Collector
    metric_collect = CollectMetrics(
        workers,
        min_history=config["metrics_smoothing_episodes"],
        timeout_seconds=config["collect_metrics_timeout"])

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.combine(
        InnerAdaptationSteps(workers, config["inner_adaptation_steps"],
                             metric_collect))

    # Metaupdate Step
    train_op = rollouts.for_each(
        MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect))

    return train_op

def execution_plan(workers, config):
    # Train TD Models on the driver.
    workers.local_worker().foreach_policy(fit_dynamics)

    # Sync the driver's policy with the workers.
    workers.sync_weights()

    # Sync TD Models and normalization stats with workers
    sync_ensemble(workers)
    sync_stats(workers)

    # Dropping metrics from the first iteration
    episodes, to_be_collected = collect_episodes(
        workers.local_worker(),
        workers.remote_workers(), [],
        timeout_seconds=9999)

    # Metrics Collector
    metric_collect = CollectMetrics(
        workers,
        min_history=0,
        timeout_seconds=config["collect_metrics_timeout"])

    inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            print("Collecting Samples, Inner Adaptation {}".format(len(split)))
            # Processing Samples (Standardize Advantages)
            samples, split_lst = post_process_samples(samples, config)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            prefix = "DynaTrajInner_" + str(adapt_iter)
            metrics = post_process_metrics(prefix, workers, metrics)

            if len(split) > inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []
                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Metaupdate Step with outer combine loop for multiple MAML iterations
    train_op = rollouts.combine(
        MetaUpdate(workers, config["num_maml_steps"],
                   config["maml_optimizer_steps"], metric_collect))

    return train_op

def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                     num_async=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync', 'raw'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.
            - In 'raw' mode, the ParallelIterator object is returned directly
              and the caller is responsible for implementing gather and
              updating the timesteps counter.
        num_async (int): In async mode, the max number of async requests in
            flight per actor.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.rollout_fragment_length

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.rollout_fragment_length * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    def report_timesteps(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              SharedMetrics()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async(
            num_async=num_async).for_each(report_timesteps)
    elif mode == "raw":
        return rollouts
    else:
        raise ValueError("mode must be one of 'bulk_sync', 'async', 'raw', "
                         "got '{}'".format(mode))

def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert (
        len(kwargs) == 0
    ), "MBMPO execution_plan does NOT take any additional parameters"

    # Train TD Models on the driver.
    workers.local_worker().foreach_policy(fit_dynamics)

    # Sync driver's policy with workers.
    workers.sync_weights()

    # Sync TD Models and normalization stats with workers
    sync_ensemble(workers)
    sync_stats(workers)

    # Dropping metrics from the first iteration
    _, _ = collect_episodes(workers.local_worker(), workers.remote_workers(),
                            [], timeout_seconds=9999)

    # Metrics Collector.
    metric_collect = CollectMetrics(
        workers,
        min_history=0,
        timeout_seconds=config["metrics_episode_collection_timeout_s"],
    )

    num_inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            print("Collecting Samples, Inner Adaptation {}".format(
                len(split)))
            # Processing Samples (Standardize Advantages)
            samples, split_lst = post_process_samples(samples, config)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            prefix = "DynaTrajInner_" + str(adapt_iter)
            metrics = post_process_metrics(prefix, workers, metrics)

            if len(split) > num_inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []
                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    # Iterator for Inner Adaptation Data gathering (from pre->post
    # adaptation).
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Meta update step with outer combine loop for multiple MAML
    # iterations.
    train_op = rollouts.combine(
        MetaUpdate(
            workers,
            config["num_maml_steps"],
            config["maml_optimizer_steps"],
            metric_collect,
        ))

    return train_op

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the MB-MPO algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training results.
    """
    # Train TD Models on the driver.
    workers.local_worker().foreach_policy(fit_dynamics)

    # Sync driver's policy with workers.
    workers.sync_weights()

    # Sync TD Models and normalization stats with workers
    sync_ensemble(workers)
    sync_stats(workers)

    # Dropping metrics from the first iteration
    _, _ = collect_episodes(workers.local_worker(), workers.remote_workers(),
                            [], timeout_seconds=9999)

    # Metrics Collector.
    metric_collect = CollectMetrics(
        workers,
        min_history=0,
        timeout_seconds=config["collect_metrics_timeout"])

    num_inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            print("Collecting Samples, Inner Adaptation {}".format(len(split)))
            # Processing Samples (Standardize Advantages)
            samples, split_lst = post_process_samples(samples, config)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            prefix = "DynaTrajInner_" + str(adapt_iter)
            metrics = post_process_metrics(prefix, workers, metrics)

            if len(split) > num_inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []
                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation).
    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Meta update step with outer combine loop for multiple MAML iterations.
    train_op = rollouts.combine(
        MetaUpdate(workers, config["num_maml_steps"],
                   config["maml_optimizer_steps"], metric_collect))

    return train_op

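# Wiring sketch (assumption, based on the older build_trainer() API that the
# docstring above references): an execution_plan like this is typically handed
# to the trainer template, which calls it with the WorkerSet and config to
# build the training iterator. DEFAULT_CONFIG and MBMPOTorchPolicy are
# illustrative names assumed to be defined elsewhere.
from ray.rllib.agents.trainer_template import build_trainer

ExampleTrainer = build_trainer(
    name="ExampleMBMPOLikeTrainer",
    default_config=DEFAULT_CONFIG,
    default_policy=MBMPOTorchPolicy,
    execution_plan=execution_plan,
)
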
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                     num_async=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Args:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of 'async', 'bulk_sync', 'raw'. In 'async' mode,
            batches are returned as soon as they are computed by rollout
            workers with no order guarantees. In 'bulk_sync' mode, we collect
            one batch from each worker and concatenate them together into a
            large batch to return. In 'raw' mode, the ParallelIterator object
            is returned directly and the caller is responsible for
            implementing gather and updating the timesteps counter.
        num_async (int): In async mode, the max number of async requests in
            flight per actor.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> from ray.rllib.execution import ParallelRollouts
        >>> workers = ... # doctest: +SKIP
        >>> rollouts = ParallelRollouts(workers, mode="async") # doctest: +SKIP
        >>> batch = next(rollouts) # doctest: +SKIP
        >>> print(batch.count) # doctest: +SKIP
        50  # config.rollout_fragment_length

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync") # doctest: +SKIP
        >>> batch = next(rollouts) # doctest: +SKIP
        >>> print(batch.count) # doctest: +SKIP
        200  # config.rollout_fragment_length * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    def report_timesteps(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the `num_workers=0` case, in which the local worker
        # has to do sampling as well.
        return LocalIterator(
            lambda timeout: workers.local_worker().item_generator,
            SharedMetrics()).for_each(report_timesteps)

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return (rollouts.batch_across_shards().for_each(
            lambda batches: SampleBatch.concat_samples(batches)).for_each(
                report_timesteps))
    elif mode == "async":
        return rollouts.gather_async(
            num_async=num_async).for_each(report_timesteps)
    elif mode == "raw":
        return rollouts
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', 'raw', got '{}'".format(
                mode))