Ejemplo n.º 1
0
def gather_experiences_tree_aggregation(workers, config):
    """Collect sample batches through a layer of intermediate aggregators.

    Tree aggregation counterpart of gather_experiences_directly(): the remote
    worker indices are dealt round-robin into
    `config["num_aggregation_workers"]` groups, and the rollouts of each group
    are handed to its own Aggregator actor.

    Args:
        workers: Worker set; the length of `workers.remote_workers()` decides
            how many shard indices get distributed.
        config: Config dict. "num_aggregation_workers" is read here and the
            whole dict is forwarded to every Aggregator.

    Returns:
        The asynchronously gathered output of the aggregator actors, where
        every batch first passes through a step that adds `batch.count` to
        the STEPS_SAMPLED_COUNTER metric.
    """
    raw_rollouts = ParallelRollouts(workers, mode="raw")

    # Deal the remote worker indices out to the aggregator slots, one at a
    # time, wrapping around once the last slot has been served.
    shard_groups = [[] for _ in range(config["num_aggregation_workers"])]
    slot = 0
    for shard_idx in range(len(workers.remote_workers())):
        shard_groups[slot].append(shard_idx)
        slot = (slot + 1) % len(shard_groups)
    logger.info("Worker assignments: {}".format(shard_groups))

    # Build one parallel iterator per aggregation group.
    group_iterators: List["ParallelIterator[SampleBatchType]"] = []
    for shard_indices in shard_groups:
        group_iterators.append(raw_rollouts.select_shards(shard_indices))

    # Spawn one intermediate Aggregator actor per group so experiences are
    # aggregated in parallel. The actors are created colocated to maximize
    # data bandwidth between them and the driver.
    aggregator_actors = []
    for group_iter in group_iterators:
        colocated = create_colocated(Aggregator, [config, group_iter], 1)
        # Only one actor was requested (count=1), so take the first entry.
        aggregator_actors.append(colocated[0])
    aggregated_batches = from_actors(aggregator_actors)

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        # Bump the sampled-steps counter, then pass the batch on unchanged.
        _get_shared_metrics().counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    gathered = aggregated_batches.gather_async()
    return gathered.for_each(record_steps_sampled)
Ejemplo n.º 2
0
def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Collect sample batches through a layer of intermediate aggregators.

    Tree aggregation counterpart of gather_experiences_directly(): the remote
    worker indices are dealt round-robin into
    `config["num_aggregation_workers"]` groups, and the rollouts of each group
    are handed to its own Aggregator actor created on the local node.

    Args:
        workers: Worker set; the length of `workers.remote_workers()` decides
            how many shard indices get distributed.
        config: Config dict. "num_aggregation_workers" is read here and the
            whole dict is forwarded to every Aggregator.

    Returns:
        The asynchronously gathered output of the aggregator actors, where
        every batch first passes through a step that updates the
        STEPS_SAMPLED_COUNTER and AGENT_STEPS_SAMPLED_COUNTER metrics.

    Raises:
        AssertionError: If `platform.node()` returns an empty string.
    """
    raw_rollouts = ParallelRollouts(workers, mode="raw")

    # Deal the remote worker indices out to the aggregator slots, one at a
    # time, wrapping around once the last slot has been served.
    shard_groups = [[] for _ in range(config["num_aggregation_workers"])]
    slot = 0
    for shard_idx in range(len(workers.remote_workers())):
        shard_groups[slot].append(shard_idx)
        slot = (slot + 1) % len(shard_groups)
    logger.info("Worker assignments: {}".format(shard_groups))

    # Build one parallel iterator per aggregation group.
    group_iterators: List["ParallelIterator[SampleBatchType]"] = []
    for shard_indices in shard_groups:
        group_iterators.append(raw_rollouts.select_shards(shard_indices))

    # The intermediate aggregator actors are pinned to the node this code
    # runs on (localhost) to maximize data bandwidth between them and the
    # driver.
    node_name = platform.node()
    assert node_name != "", ("ERROR: Cannot determine local node name! "
                             "`platform.node()` returned empty string.")

    # One spec per aggregation group: (class, args, kwargs={}, count=1).
    aggregator_specs = []
    for group_iter in group_iterators:
        aggregator_specs.append((Aggregator, [config, group_iter], {}, 1))
    colocated_groups = create_colocated_actors(
        actor_specs=aggregator_specs,
        node=node_name,
    )

    # Every spec asked for a single actor (count=1), so each returned group
    # contributes its first ([0]) entry.
    aggregator_actors = [actors[0] for actors in colocated_groups]
    aggregated_batches = from_actors(aggregator_actors)

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        # Multi-agent batches report their own agent-step total; for any
        # other batch the plain count is used for agent steps as well.
        agent_steps = (batch.agent_steps()
                       if isinstance(batch, MultiAgentBatch) else batch.count)
        metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += agent_steps
        return batch

    gathered = aggregated_batches.gather_async()
    return gathered.for_each(record_steps_sampled)