def gather_experiences_tree_aggregation(workers, config):
    """Tree aggregation version of gather_experiences_directly().

    The remote rollout workers are dealt out round-robin into
    ``config["num_aggregation_workers"]`` groups. For each group one
    ``Aggregator`` actor is created through ``create_colocated`` and the
    resulting actors are merged into a single asynchronous iterator.

    Args:
        workers: Worker set whose remote workers produce the rollouts.
        config: Trainer config. ``"num_aggregation_workers"`` is read here
            and the whole dict is forwarded to every ``Aggregator``.

    Returns:
        The ``gather_async()`` iterator over the aggregators' outputs, with a
        per-batch hook that adds ``batch.count`` to the shared
        ``STEPS_SAMPLED_COUNTER`` metric before passing the batch on.
    """
    rollouts = ParallelRollouts(workers, mode="raw")

    # Deal remote worker indices out round-robin: one bucket per
    # aggregation worker, cycling back to bucket 0 after the last one.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    slot = 0
    for worker_index in range(len(workers.remote_workers())):
        worker_assignments[slot].append(worker_index)
        slot = (slot + 1) % len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # One parallel iterator per aggregation group, restricted to that
    # group's shards.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = []
    for shard_ids in worker_assignments:
        rollout_groups.append(rollouts.select_shards(shard_ids))

    # Spawn one intermediate Aggregator actor per group so experiences are
    # aggregated in parallel. The intent is to colocate them with the driver
    # to maximize data bandwidth.
    # NOTE(review): create_colocated is invoked once per group (count=1), so
    # whether all aggregators land on the same node depends on its
    # semantics -- confirm.
    aggregators = []
    for group in rollout_groups:
        colocated = create_colocated(Aggregator, [config, group], 1)
        aggregators.append(colocated[0])
    train_batches = from_actors(aggregators)

    # TODO(ekl) properly account for replay.
    def _count_sampled_steps(batch):
        shared_metrics = _get_shared_metrics()
        shared_metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    merged = train_batches.gather_async()
    return merged.for_each(_count_sampled_steps)
def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly().

    Remote rollout workers are split round-robin across
    ``config["num_aggregation_workers"]`` aggregation groups. One
    ``Aggregator`` actor is requested per group on the local node (as
    reported by ``platform.node()``), and the aggregators' outputs are
    merged asynchronously.

    Args:
        workers: The WorkerSet whose remote workers produce rollouts.
        config: Trainer config. Must contain ``"num_aggregation_workers"``;
            the whole dict is also forwarded to every ``Aggregator``.

    Returns:
        A local iterator over the aggregators' train batches. For every
        batch it yields, it increments the shared ``STEPS_SAMPLED_COUNTER``
        and ``AGENT_STEPS_SAMPLED_COUNTER`` metrics.

    Raises:
        ValueError: If ``num_aggregation_workers`` is less than 1.
        RuntimeError: If the local node name cannot be determined.
    """
    num_aggregators = config["num_aggregation_workers"]
    # Without at least one aggregator there is nowhere to send rollouts
    # (previously this surfaced as an opaque IndexError).
    if num_aggregators < 1:
        raise ValueError(
            "Tree aggregation requires `num_aggregation_workers` >= 1, "
            "got {}.".format(num_aggregators))

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let an empty node name through silently.
    localhost = platform.node()
    if not localhost:
        raise RuntimeError("ERROR: Cannot determine local node name! "
                           "`platform.node()` returned empty string.")

    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators round-robin: remote worker
    # `idx` goes to aggregator `idx % num_aggregators`.
    num_workers = len(workers.remote_workers())
    worker_assignments = [
        list(range(first, num_workers, num_aggregators))
        for first in range(num_aggregators)
    ]
    logger.info("Worker assignments: %s", worker_assignments)

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    all_co_located = create_colocated_actors(
        actor_specs=[
            # (class, args, kwargs={}, count=1)
            (Aggregator, [config, g], {}, 1) for g in rollout_groups
        ],
        node=localhost,
    )
    # Each created group holds exactly one actor (count=1); use it.
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        # MultiAgentBatch exposes agent_steps(); any other batch type uses
        # its env-step count for the agent-step counter as well.
        if isinstance(batch, MultiAgentBatch):
            agent_steps = batch.agent_steps()
        else:
            agent_steps = batch.count
        metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += agent_steps
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)