Пример #1
0
def execution_plan(workers, config):
    """Build a replay-based training pipeline that alternates store and train.

    Two ops are interleaved deterministically (round-robin):

    1. ``store_op`` samples rollouts from ``workers`` (bulk-synchronous mode)
       and writes them into a local replay buffer.
    2. ``replay_op`` draws ``train_batch_size`` batches from that buffer,
       takes one training step on each, then runs the target-network update
       check.

    Args:
        workers: The worker set used both for sampling and for training.
        config: Configuration mapping; this function reads ``buffer_size``,
            ``train_batch_size`` and ``target_network_update_freq``.

    Returns:
        The ``StandardMetricsReporting`` wrapper around the combined op.
    """
    replay_buffer = ReplayBuffer(config["buffer_size"])
    sampler = ParallelRollouts(workers, mode="bulk_sync")

    # Stage 1: each next() on store_op pulls a rollout and stores it in the
    # local replay buffer.
    store_op = sampler.for_each(StoreToReplayBuffer(replay_buffer))

    # Stage 2: every batch replayed from the buffer gets one SGD step via
    # TrainOneStep, after which UpdateTargetNetwork decides whether the
    # target network should be refreshed.
    replayed = LocalReplay(replay_buffer, config["train_batch_size"])
    trained = replayed.for_each(TrainOneStep(workers))
    replay_op = trained.for_each(
        UpdateTargetNetwork(workers, config["target_network_update_freq"]))

    # Interleave the two stages deterministically.
    train_op = Concurrently([store_op, replay_op], mode="round_robin")
    return StandardMetricsReporting(train_op, workers, config)
Пример #2
0
def execution_plan(workers, config):
    """Build an asynchronous pipeline over sharded replay actors and a learner.

    Three ops run concurrently, as fast as each can progress:

    1. Rollouts from ``workers`` are stored into the replay actors. The worker
       that produced a batch is periodically sent the latest weights.
    2. Batches replayed from the replay actors are pushed onto the learner
       thread's input queue.
    3. Results taken from the learner thread's output queue update replay
       priorities and metrics, then drive the target-network update check.

    Args:
        workers: Worker set; ``local_worker()`` supplies the model trained by
            the learner thread and the weights broadcast to rollout actors.
        config: Configuration mapping. Reads ``optimizer`` ->
            ``num_replay_buffer_shards`` / ``max_weight_sync_delay``, plus
            ``learning_starts``, ``buffer_size``, ``train_batch_size``,
            ``prioritized_replay_alpha`` / ``_beta`` / ``_eps`` and
            ``target_network_update_freq``.

    Returns:
        The ``StandardMetricsReporting`` wrapper around the merged op.
    """
    # Create the replay buffer actor shards.
    # TODO(ekl) support batch replay options
    shard_count = config["optimizer"]["num_replay_buffer_shards"]
    replay_actor_args = [
        shard_count,
        config["learning_starts"],
        config["buffer_size"],
        config["train_batch_size"],
        config["prioritized_replay_alpha"],
        config["prioritized_replay_beta"],
        config["prioritized_replay_eps"],
    ]
    replay_actors = create_colocated(
        ReplayActor, replay_actor_args, shard_count)

    # Start the learner thread. The helpers defined below close over it.
    learner_thread = LearnerThread(workers.local_worker())
    learner_thread.start()

    def update_prio_and_stats(item):
        """Send learner priorities back to their replay actor; record stats."""
        actor, prio_dict, count = item
        actor.update_priorities.remote(prio_dict)
        metrics = LocalIterator.get_metrics()
        # Training happens on learner_thread, so the trained-step counter is
        # advanced here using the count that accompanies each result.
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.timers["learner_dequeue"] = learner_thread.queue_timer
        metrics.timers["learner_grad"] = learner_thread.grad_timer
        metrics.timers["learner_overall"] = learner_thread.overall_timer

    class UpdateWorkerWeights:
        """Callable that periodically pushes local weights to rollout actors.

        Steps sampled per actor are accumulated. Once an actor reaches
        ``max_weight_sync_delay`` steps since its last sync, it is sent the
        current weights and its counter is reset to zero.
        """

        def __init__(self, learner_thread, workers, max_weight_sync_delay):
            self.learner_thread = learner_thread
            self.workers = workers
            self.max_weight_sync_delay = max_weight_sync_delay
            # Steps sampled by each actor since it last received weights.
            self.steps_since_update = collections.defaultdict(int)
            # Cached ray.put() reference to the most recently fetched weights.
            self.weights = None

        def __call__(self, item):
            actor, batch = item
            counts = self.steps_since_update
            counts[actor] += batch.count
            if counts[actor] < self.max_weight_sync_delay:
                return

            # Re-fetch weights when none are cached yet or the learner has
            # flagged an update. Pulling new weights once they change avoids
            # excessive correlation between actors.
            needs_refresh = (
                self.weights is None or self.learner_thread.weights_updated)
            if needs_refresh:
                self.learner_thread.weights_updated = False
                local_weights = self.workers.local_worker().get_weights()
                self.weights = ray.put(local_weights)

            actor.set_weights.remote(self.weights)
            counts[actor] = 0
            LocalIterator.get_metrics().counters["num_weight_syncs"] += 1

    # (1) Store rollouts in the replay actors and sync weights back to the
    # worker that generated each batch.
    rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2)
    stored = rollouts.for_each(StoreToReplayActors(replay_actors))
    store_op = stored.zip_with_source_actor().for_each(
        UpdateWorkerWeights(
            learner_thread,
            workers,
            max_weight_sync_delay=config["optimizer"]["max_weight_sync_delay"],
        ))

    # (2) Feed replayed experiences to the learner thread's in-queue.
    replayed = ParallelReplay(replay_actors, async_queue_depth=4)
    replay_op = replayed.zip_with_source_actor().for_each(
        Enqueue(learner_thread.inqueue))

    # (3) Apply priorities from the learner's out-queue to the replay actors,
    # then run the (steps-trained based) target network update check.
    learner_results = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive)
    with_priorities = learner_results.for_each(update_prio_and_stats)
    update_op = with_priorities.for_each(
        UpdateTargetNetwork(
            workers,
            config["target_network_update_freq"],
            by_steps_trained=True,
        ))

    # Run all three ops asynchronously, each as fast as it can progress.
    merged_op = Concurrently([store_op, replay_op, update_op], mode="async")
    return StandardMetricsReporting(merged_op, workers, config)