Пример #1
0
def execution_plan(workers, config, **kwargs):
    """Build the AlphaZero training dataflow.

    With ``simple_optimizer`` enabled, rollout batches are concatenated and
    trained on directly. Otherwise, samples are routed through a local
    replay buffer and training only begins after ``learning_starts``
    timesteps have been collected.

    Args:
        workers: The WorkerSet used for sampling and training.
        config: The trainer's configuration dict.

    Returns:
        A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "Alpha zero execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    if config["simple_optimizer"]:
        # Direct path: concat rollout batches, then run SGD on each batch.
        train_op = (
            rollouts
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                ))
            .for_each(
                TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"])))
    else:
        # Replay path: store every sample, train from the buffer once
        # enough timesteps have elapsed.
        buf = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))

        replay_op = (
            Replay(local_buffer=buf)
            .filter(WaitUntilTimestepsElapsed(config["learning_starts"]))
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                ))
            .for_each(
                TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"])))

        # Interleave storing and replaying; only replay results (index 1)
        # are emitted downstream.
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Пример #2
0
def execution_plan(workers, config):
    """Build the training dataflow for this trainer.

    Chooses between training directly on concatenated rollout batches
    (``simple_optimizer``) and a replay-buffer-based path that waits for
    ``learning_starts`` timesteps before training.

    Args:
        workers: The WorkerSet used for sampling and training.
        config: The trainer's configuration dict.

    Returns:
        A local iterator over training metrics.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    if config["simple_optimizer"]:
        # Direct path: batch up samples and train on each batch.
        train_op = (
            rollouts
            .combine(ConcatBatches(
                min_batch_size=config["train_batch_size"]))
            .for_each(TrainOneStep(
                workers, num_sgd_iter=config["num_sgd_iter"])))
    else:
        # Replay path: store samples, then train from the buffer.
        buf = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))

        replay_op = (
            Replay(local_buffer=buf)
            .filter(WaitUntilTimestepsElapsed(config["learning_starts"]))
            .combine(ConcatBatches(
                min_batch_size=config["train_batch_size"]))
            .for_each(TrainOneStep(
                workers, num_sgd_iter=config["num_sgd_iter"])))

        # Alternate store/replay; only replay output (index 1) is emitted.
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Пример #3
0
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Build the QMIX training dataflow.

        Samples are stored into a local replay buffer; training draws
        concatenated batches from the buffer and periodically syncs the
        target network.

        Args:
            workers: The WorkerSet used for sampling and training.
            config: The trainer's configuration dict.

        Returns:
            LocalIterator[dict]: A local iterator over training metrics.
        """
        assert (
            len(kwargs) == 0
        ), "QMIX execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        buffer = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=buffer))

        # Train on replayed batches, then update the target network on the
        # configured cadence.
        train_op = Replay(local_buffer=buffer) \
            .combine(ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
            .for_each(TrainOneStep(workers)) \
            .for_each(UpdateTargetNetwork(
                workers, config["target_network_update_freq"]))

        # Alternate store/train; only training output (index 1) is emitted.
        merged_op = Concurrently(
            [store_op, train_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(merged_op, workers, config)
Пример #4
0
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the MARWIL/BC algorithm.

    Defines the distributed dataflow: rollouts are stored into a local
    replay buffer, and training draws concatenated batches from it.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    buf = SimpleReplayBuffer(config["replay_buffer_size"])

    store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))

    # Pull batches from the buffer, concatenate, and run one train step.
    replay_op = (
        Replay(local_buffer=buf)
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            ))
        .for_each(TrainOneStep(workers)))

    # Alternate store/replay; only replay output (index 1) is emitted.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Пример #5
0
def execution_plan(workers, config):
    """Build the replay-buffer-based training dataflow.

    Rollouts feed a local replay buffer; training draws concatenated
    batches from the buffer and runs one train step per batch.

    Args:
        workers: The WorkerSet used for sampling and training.
        config: The trainer's configuration dict.

    Returns:
        A local iterator over training metrics.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    buf = SimpleReplayBuffer(config["replay_buffer_size"])

    store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))

    replay_op = (
        Replay(local_buffer=buf)
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"]))
        .for_each(TrainOneStep(workers)))

    # Alternate store/replay; only replay output (index 1) is emitted.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Пример #6
0
def execution_plan(workers, config):
    """Build the training dataflow with target-network updates.

    Rollouts are stored into a local replay buffer; training draws
    concatenated batches from the buffer, runs one train step, and syncs
    the target network every ``target_network_update_freq`` steps.

    Args:
        workers: The WorkerSet used for sampling and training.
        config: The trainer's configuration dict.

    Returns:
        A local iterator over training metrics.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    buf = SimpleReplayBuffer(config["buffer_size"])

    store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))

    train_op = (
        Replay(local_buffer=buf)
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"]
            ))
        .for_each(TrainOneStep(workers))
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"])))

    # Alternate store/train; only training output (index 1) is emitted.
    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(merged_op, workers, config)
Пример #7
0
 def new_buffer():
     # Factory returning a fresh SimpleReplayBuffer. `capacity` is a free
     # variable from an enclosing scope not visible in this chunk —
     # presumably the enclosing function's configured buffer size; verify
     # against the surrounding definition.
     return SimpleReplayBuffer(num_slots=capacity)